ynnpack/xnnpack/utils.cc (194 changes: 148 additions & 46 deletions)
@@ -411,11 +411,13 @@ ynn_status convert_to(ynn_subgraph_t subgraph, uint32_t* value_id,
return status;
}

ynn_status define_convert_dot_inputs(ynn_subgraph_t subgraph,
uint32_t input_a_id, uint32_t* input_b_id,
uint32_t* bias_id) {
if (type_of_value(subgraph, *input_b_id) == ynn_type_uint8) {
const ynn_value& b = subgraph->value(*input_b_id);
ynn_status define_xnn_dot_float(ynn_subgraph_t subgraph, size_t num_k_dims,
uint32_t a_id, uint32_t b_id, uint32_t bias_id,
uint32_t output_id) {
ynn_type a_type = type_of_value(subgraph, a_id);

if (type_of_value(subgraph, b_id) == ynn_type_uint8) {
const ynn_value& b = subgraph->value(b_id);
ynn_status status;
// Convert uint8 to int8
uint32_t zero_point_id = YNN_INVALID_VALUE_ID;
@@ -439,62 +441,123 @@ ynn_status define_convert_dot_inputs(ynn_subgraph_t subgraph,
return status;
}

status = ynn_define_unary(subgraph, ynn_unary_convert, *input_b_id,
status = ynn_define_unary(subgraph, ynn_unary_convert, b_id,
&b_int8_id, /*flags=*/0);
if (status != ynn_status_success) {
return status;
}
*input_b_id = b_int8_id;
} else if (!type_is_integral(type_of_value(subgraph, input_a_id))) {
// XNNPACK allows a mix of fp16 and fp32 inputs, and it always converts the
// weights and bias to the same type as the input.
ynn_type a_type = type_of_value(subgraph, input_a_id);
b_id = b_int8_id;
} else {
// TODO(dsharlet): XNNPACK also supports fp input, quantized weights, but
// that support is questionably correct/useful, so leaving it for later.
assert(!type_is_integral(type_of_value(subgraph, *input_b_id)));
ynn_status status = convert_to(subgraph, input_b_id, a_type);
assert(!type_is_integral(type_of_value(subgraph, b_id)));
ynn_status status = convert_to(subgraph, &b_id, a_type);
if (status != ynn_status_success) {
return status;
}
}

uint32_t bias_converted_id = bias_id;
if (bias_converted_id != YNN_INVALID_VALUE_ID) {
// We need biases to be fp32, so we can initialize the accumulators, which
// are always fp32 for floating point inputs.
status = convert_to(subgraph, bias_id, ynn_type_fp32);
ynn_status status =
convert_to(subgraph, &bias_converted_id, ynn_type_fp32);
if (status != ynn_status_success) {
return status;
}
}
return ynn_status_success;
}

} // namespace

ynn_type accumulator_for_type(ynn_type type) {
if (type_promotes_to_float(type)) {
return ynn_type_fp32;
} else {
return ynn_type_int32;
uint32_t init_accumulator_id = YNN_INVALID_VALUE_ID;
uint32_t accumulator_id = output_id;
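// If we have a bias, the dot product creates an intermediate value to
// accumulate into instead of reusing `output_id`.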
const bool allow_reuse = (bias_converted_id == YNN_INVALID_VALUE_ID);
ynn_status status = define_xnn_accumulator_for_dot(
subgraph, num_k_dims, a_id, b_id, &init_accumulator_id, &accumulator_id,
allow_reuse);
if (status != ynn_status_success) {
return status;
}
}

ynn_status define_xnn_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
uint32_t a_id, uint32_t b_id, uint32_t bias_id,
uint32_t output_id) {
uint32_t bias_converted_id = bias_id;
ynn_status status =
define_convert_dot_inputs(subgraph, a_id, &b_id, &bias_converted_id);
status = ynn_define_dot(subgraph, num_k_dims, a_id, b_id, init_accumulator_id,
&accumulator_id, /*flags=*/0);
if (status != ynn_status_success) {
return status;
}

uint32_t output_unconverted_id = accumulator_id;
if (bias_converted_id != YNN_INVALID_VALUE_ID) {
uint32_t added_id = YNN_INVALID_VALUE_ID;
status = define_binary_with_broadcasting(
subgraph, ynn_binary_add, accumulator_id, bias_converted_id, &added_id,
/*flags=*/0);
if (status != ynn_status_success) {
return status;
}
output_unconverted_id = added_id;
}

if (output_unconverted_id != output_id) {
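// A conversion is only needed when the output type differs; otherwise a
// plain copy routes the intermediate result into `output_id`.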
if (type_of_value(subgraph, output_id) !=
type_of_value(subgraph, output_unconverted_id)) {
status = ynn_define_unary(subgraph, ynn_unary_convert,
output_unconverted_id, &output_id, /*flags=*/0);
if (status != ynn_status_success) {
return status;
}
} else {
status = ynn_define_copy(subgraph, output_unconverted_id, &output_id,
/*flags=*/0);
if (status != ynn_status_success) {
return status;
}
}
}

return ynn_status_success;
}

ynn_status define_xnn_dot_int(ynn_subgraph_t subgraph, size_t num_k_dims,
uint32_t a_id, uint32_t b_id, uint32_t bias_id,
uint32_t output_id) {
if (type_of_value(subgraph, b_id) == ynn_type_uint8) {
const ynn_value& b = subgraph->value(b_id);
ynn_status status;
// Convert uint8 to int8
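// (Reinterpreting uint8 as int8 shifts values down by 128, so the zero
// point is shifted down by 128 as well; with no explicit zero point, a
// constant -128 is used.)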
uint32_t zero_point_id = YNN_INVALID_VALUE_ID;
if (b.zero_point_id != YNN_INVALID_VALUE_ID) {
status =
ynn::define_binary_scalar_b(subgraph, ynn_binary_subtract,
b.zero_point_id, 128.0f, &zero_point_id);
if (status != ynn_status_success) {
return status;
}
} else {
zero_point_id = subgraph->get_scalar_value_id<int32_t>(-128);
}

uint32_t b_int8_id = YNN_INVALID_VALUE_ID;
status = ynn_define_tensor_value(subgraph, ynn_type_int8, /*rank=*/0,
/*dims=*/nullptr, /*data=*/nullptr,
zero_point_id, b.scale_id, /*flags=*/0,
&b_int8_id);
if (status != ynn_status_success) {
return status;
}

status = ynn_define_unary(subgraph, ynn_unary_convert, b_id,
&b_int8_id, /*flags=*/0);
if (status != ynn_status_success) {
return status;
}
b_id = b_int8_id;
}

uint32_t init_accumulator_id = YNN_INVALID_VALUE_ID;
uint32_t accumulator_id = output_id;
// If we have a bias, the dot product creates an intermediate value to
// accumulate into instead of reusing `output_id`.
const bool allow_reuse = (bias_converted_id == YNN_INVALID_VALUE_ID);
status = define_xnn_accumulator_for_dot(subgraph, num_k_dims, a_id, b_id,
&init_accumulator_id, &accumulator_id,
allow_reuse);
const bool allow_reuse = (bias_id == YNN_INVALID_VALUE_ID);
ynn_status status = define_xnn_accumulator_for_dot(
subgraph, num_k_dims, a_id, b_id, &init_accumulator_id, &accumulator_id,
allow_reuse);
if (status != ynn_status_success) {
return status;
}
@@ -507,7 +570,6 @@ ynn_status define_xnn_dot(ynn_subgraph_t subgraph, size_t num_k_dims,

// XNNPACK semantics: output = scale * (dot_result) + bias.
// We perform this in fp32 to support arbitrary bias scales and fusing.

uint32_t dot_result_as_float_id = accumulator_id;

// If the accumulator is integral (quantized), we must convert it to float.
@@ -519,41 +581,81 @@ ynn_status define_xnn_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID,
/*flags=*/0, &float_val_id);
if (status != ynn_status_success) return status;
if (status != ynn_status_success) {
return status;
}

status = ynn_define_unary(subgraph, ynn_unary_convert, accumulator_id,
&float_val_id, /*flags=*/0);
if (status != ynn_status_success) return status;
if (status != ynn_status_success) {
return status;
}

dot_result_as_float_id = float_val_id;
}

uint32_t output_unconverted_id = dot_result_as_float_id;
if (bias_converted_id != YNN_INVALID_VALUE_ID) {
if (bias_id != YNN_INVALID_VALUE_ID) {
uint32_t bias_converted_id = bias_id;
// Ensure bias is also fp32 for the addition.
status = convert_to(subgraph, &bias_converted_id, ynn_type_fp32);
if (status != ynn_status_success) return status;
if (status != ynn_status_success) {
return status;
}

uint32_t added_id = YNN_INVALID_VALUE_ID;
status = define_binary_with_broadcasting(
subgraph, ynn_binary_add, dot_result_as_float_id, bias_converted_id,
&added_id, /*flags=*/0);
if (status != ynn_status_success) return status;
if (status != ynn_status_success) {
return status;
}

output_unconverted_id = added_id;
}

if (output_unconverted_id != output_id) {
status = ynn_define_unary(subgraph, ynn_unary_convert,
output_unconverted_id, &output_id, /*flags=*/0);
if (status != ynn_status_success) {
return status;
if (type_of_value(subgraph, output_id) !=
type_of_value(subgraph, output_unconverted_id)) {
status = ynn_define_unary(subgraph, ynn_unary_convert,
output_unconverted_id, &output_id, /*flags=*/0);
if (status != ynn_status_success) {
return status;
}
} else {
status = ynn_define_copy(subgraph, output_unconverted_id, &output_id,
/*flags=*/0);
if (status != ynn_status_success) {
return status;
}
}
}

return ynn_status_success;
}

} // namespace

ynn_type accumulator_for_type(ynn_type type) {
if (type_promotes_to_float(type)) {
return ynn_type_fp32;
} else {
return ynn_type_int32;
}
}

ynn_status define_xnn_dot(ynn_subgraph_t subgraph, size_t num_k_dims,
uint32_t a_id, uint32_t b_id, uint32_t bias_id,
uint32_t output_id) {
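// Dispatch on the input type: inputs that promote to float take the
// floating-point path, everything else takes the integer path.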
if (type_promotes_to_float(type_of_value(subgraph, a_id))) {
return define_xnn_dot_float(subgraph, num_k_dims, a_id, b_id, bias_id,
output_id);
} else {
return define_xnn_dot_int(subgraph, num_k_dims, a_id, b_id, bias_id,
output_id);
}
}

ynn_status define_binary_scalar_a(ynn_subgraph_t subgraph,
ynn_binary_operator op, float scalar_a,
uint32_t input_b_id, uint32_t* output_id) {