Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions backends/xnnpack/runtime/XNNExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ ET_NODISCARD Error XNNExecutor::prepare_args(Span<EValue*> args) {
xnn_status_to_string(status));
}
}
// // Propagate Input Shape and Memory Plan for increased allocation
// Propagate Input Shape and Memory Plan for increased allocation
status = xnn_reshape_runtime(runtime_.get());

ET_CHECK_OR_RETURN_ERROR(
Expand All @@ -136,6 +136,12 @@ ET_NODISCARD Error XNNExecutor::prepare_args(Span<EValue*> args) {
"Internal Error: Propagating input shapes failed with code: %s",
xnn_status_to_string(status));

// Resize output tensors.
Error err = resize_outputs(args);
if (err != Error::Ok) {
return err;
}

return Error::Ok;
}

Expand Down Expand Up @@ -188,14 +194,7 @@ ET_NODISCARD Error XNNExecutor::forward(BackendExecutionContext& context) {
}

/**
* Prepares the outputs for ExecuTorch
*
* Resizes the output tensors based on the output shapes returned by
* the xnnpack runtime.
*
* Note: For arg_max pooling, we recast the output index tensor. Since
* XNNPACK gives the index tensor to us as int32, we need to convert it
* back to int64 for ExecuTorch.
* Resizes output tensors to match XNNPACK's computed shapes.
*/
ET_NODISCARD Error XNNExecutor::resize_outputs(Span<EValue*> args) const {
size_t output_idx_start = input_ids_.size();
Expand Down Expand Up @@ -239,6 +238,22 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(Span<EValue*> args) const {
ET_LOG(Error, "Failed to resize output tensor for XNNExecutor");
return err;
}
}

return Error::Ok;
}

/**
* Converts output data types after XNNPACK execution.
*
* For arg_max pooling, XNNPACK outputs int32 index tensors that need
* to be converted to int64 for ExecuTorch.
*/
ET_NODISCARD Error XNNExecutor::convert_outputs(Span<EValue*> args) const {
size_t output_idx_start = input_ids_.size();
for (size_t i = output_idx_start; i < externals_.size(); ++i) {
uint32_t ext_id = externals_[i].id;
Tensor* out_tensor = &args[ext_id]->toTensor();

// Output datatype is int64. However, XNNPACK doesn't support
// int64. This means that the data was put into this tensor
Expand Down
12 changes: 10 additions & 2 deletions backends/xnnpack/runtime/XNNExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,21 @@ class XNNExecutor {
executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext& context);

/**
* Prepares the outputs to be returned by the delegate
* Resizes output tensors to match XNNPACK's computed shapes.
*
* Performs any post processing of outputs like tensor resizing
*/
ET_NODISCARD executorch::runtime::Error resize_outputs(
executorch::runtime::Span<executorch::runtime::EValue*> args) const;

/**
* Converts output data types after XNNPACK execution.
*
* For arg_max pooling, XNNPACK outputs int32 index tensors that need
* to be converted to int64 for ExecuTorch.
*/
ET_NODISCARD executorch::runtime::Error convert_outputs(
executorch::runtime::Span<executorch::runtime::EValue*> args) const;

friend class XNNCompiler;
};

Expand Down
4 changes: 2 additions & 2 deletions backends/xnnpack/runtime/XNNPACKBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ class XnnpackBackend final
return err;
}

// Resize outputs and recast pointers if necessary
err = executor->resize_outputs(args);
// Convert output data types if necessary (e.g., int32 -> int64 for Long)
err = executor->convert_outputs(args);

return err;
}
Expand Down
2 changes: 1 addition & 1 deletion backends/xnnpack/test/runtime/test_xnnexecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ TEST(XNNExecutorTest, ResizeOutputsWithLongTensorConvertsInt32ToInt64) {
ASSERT_EQ(executor.prepare_args(span), Error::Ok);
executorch::ET_RUNTIME_NAMESPACE::BackendExecutionContext context;
ASSERT_EQ(executor.forward(context), Error::Ok);
ASSERT_EQ(executor.resize_outputs(span), Error::Ok);
ASSERT_EQ(executor.convert_outputs(span), Error::Ok);

Tensor& result = args[2]->toTensor();
ASSERT_EQ(result.scalar_type(), executorch::aten::ScalarType::Long);
Expand Down
Loading