Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized atan2, _softmax, cat, clamp, full, relu, remainder, permute_copy_out ops and updates to use memory_allocator #7567

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
216389c
Adding mean and where ops optimized on HiFi
dijopaul Oct 23, 2024
3d849bb
Merge pull request #14 from dijopaul/main
cad-audio Oct 24, 2024
9b71aed
Adding quantized linear optimized versions for int8 and uint8
dijopaul Nov 6, 2024
07743ab
adding pow, remainder, minimum, maximum operators (#33)
nishpoonia Nov 7, 2024
edc1b3d
Fix for build issue faced in div_mod on old tools
dijopaul Nov 13, 2024
222beee
Merge pull request #15 from dijopaul/main
cad-audio Nov 14, 2024
6e074ec
Merge branch 'main' into main
cad-audio Nov 14, 2024
afca3db
Fix build failure due to merge issue
dijopaul Nov 19, 2024
10a0ee0
Merge branch 'main' into main
mcremon-meta Nov 21, 2024
f1f0bb3
Fixing review comments on PR 6867
dijopaul Nov 22, 2024
f8cf408
Malloc fix (#39)
dijopaul Nov 28, 2024
911021f
Cleaning cmakelist to avoid duplications
dijopaul Dec 2, 2024
18cf518
Fixing lint issues and removing free statements
dijopaul Dec 3, 2024
5e471f2
adding ET_KERNEL_CHECK for allocate_temp_memory (#41)
nishpoonia Dec 23, 2024
6928f95
Merge branch 'main' into main_PR18
dijopaul Jan 9, 2025
991961b
Fixing lint error due to merge
dijopaul Jan 9, 2025
7585ee0
Merge pull request #18 from dijopaul/main_PR18
cad-audio Jan 9, 2025
540243a
Update functions_hifi.yaml
dijopaul Jan 9, 2025
85e7c59
Merge pull request #19 from dijopaul/patch-1
cad-audio Jan 9, 2025
1f681c7
Incorporating review comments: removing nesting to check data type an…
nishpoonia Jan 10, 2025
3539f52
clean up
nishpoonia Jan 13, 2025
fe5e7d7
Merge pull request #20 from dijopaul/main_PR18
cad-audio Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding quantized linear optimized versions for int8 and uint8
  • Loading branch information
dijopaul committed Nov 6, 2024
commit 9b71aeda0388c73c4c607911c0a7c581f107dc17
90 changes: 57 additions & 33 deletions backends/cadence/hifi/operators/quantized_linear_out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,50 +21,74 @@ using executorch::runtime::getLeadingDims;
using executorch::runtime::KernelRuntimeContext;

// Quantized linear operator for the HiFi backend: hands `src`, `weight` and
// `bias` to an nnlib quantized matmul kernel and writes the result into `out`.
//
// NOTE(review): this span is a diff rendering whose +/- markers and
// indentation were stripped, so removed and added lines appear back to back
// (for example, two `ctx` parameter lines and two weight zero-point parameter
// lines). Which line of each pair is the removed one is inferred from the
// hunk, not stated by it -- confirm against the actual commit.
void quantized_linear_out(
// The next two lines look like the old and new forms of the same parameter;
// the second form tags `ctx` with __ET_UNUSED.
KernelRuntimeContext& ctx,
__ET_UNUSED KernelRuntimeContext& ctx,
const Tensor& src,
const Tensor& weight,
const Tensor& bias,
int64_t src_zero_point,
// Old/new pair: the parameter appears to be renamed to `weight_zero_point_t`.
// Only element [0] of this tensor is read, as an int32_t, in the calls below.
const Tensor& weight_zero_point,
const Tensor& weight_zero_point_t,
const Tensor& out_multiplier,
const Tensor& out_shift,
int64_t out_zero_point,
// Old/new pair: the second form tags `offset` with __ET_UNUSED; `offset` is
// not referenced anywhere in the visible body.
const executorch::aten::optional<Tensor>& offset,
__ET_UNUSED const executorch::aten::optional<Tensor>& offset,
Tensor& out) {
// input comes in shape [leading_dims, in_dim]
// weight comes in shape [out_dim, in_dim]
// output comes in empty with shape [leading_dims, out_dim]
// Perform matrix multiply (M x N) x (N x P)' => M x P
int64_t leading_dims = getLeadingDims(src, src.dim() - 1);
// Old/new pairs: the two declarations of out_dim / in_dim differ only in
// their trailing comments.
int64_t out_dim = weight.size(0); // = out_dim
int64_t in_dim = weight.size(1); // = in_dim
int64_t out_dim = weight.size(0);
int64_t in_dim = weight.size(1);

// Presumably the removed (pre-change) pointer setup, which treated every
// tensor as uint8 unconditionally; the same four declarations reappear
// inside the Byte branch below.
const uint8_t* __restrict__ in_data = src.const_data_ptr<uint8_t>();
const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
// Dispatch on the dtype of `out`: Byte takes the uint8 path, Char (further
// below) takes the int8 path, anything else hits the ET_CHECK_MSG fallback.
if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
const uint8_t* __restrict__ in_data = src.const_data_ptr<uint8_t>();
const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();

// The nnlib kernel to compute quantized linear via matmul.
// Presumably the removed call: it goes through the
// cadence::impl::HiFi::kernels wrapper, passes out_multiplier / out_shift as
// pointers, takes a trailing per-channel flag, and checks the return code
// with ET_DCHECK_MSG.
int32_t ret = cadence::impl::HiFi::kernels::matmul_asym8uxasym8u_asym8u(
out_data, // p_out
weight_data, // p_mat1,
in_data, // p_mat2,
bias_data, // p_bias
out_dim, // rows of p_mat1
in_dim, // cols of p_mat1
in_dim, // row_stride of p_mat1
leading_dims, // vec_count, i.e., rows of p_mat2
in_dim, // vec_offset of p_mat2.
out_dim, // out_offset, i.e., offset of next output element written
1, // out_stride, i.e., stride to go to next output row
-weight_zero_point.const_data_ptr<int32_t>()[0], // mat1_zero_bias
-src_zero_point, // mat2_zero_bias
out_multiplier.const_data_ptr<int32_t>(), // out_multiplier
out_shift.const_data_ptr<int32_t>(), // out_shift
out_zero_point, // out_zero_bias
false); // per channel quantization
ET_DCHECK_MSG(ret == 0, "HiFi quantized::linear failed");
// The nnlib kernel to compute quantized linear via matmul.
// Presumably the added call: invokes xa_nn_matmul_asym8uxasym8u_asym8u
// directly. The 16 arguments are the same expressions, in the same order, as
// the first 16 of the annotated call above (p_out, p_mat1, p_mat2, p_bias,
// rows, cols, row_stride, vec_count, vec_offset, out_offset, out_stride,
// mat1_zero_bias, mat2_zero_bias, out_multiplier, out_shift, out_zero_bias),
// except that out_multiplier / out_shift are passed as element [0] rather
// than as pointers and there is no trailing per-channel flag.
// NOTE(review): the return value is not captured here, so the status check
// performed by the ET_DCHECK_MSG above has no counterpart -- confirm that
// dropping it is intended.
xa_nn_matmul_asym8uxasym8u_asym8u(
out_data,
weight_data,
in_data,
bias_data,
out_dim,
in_dim,
in_dim,
leading_dims,
in_dim,
out_dim,
1,
-weight_zero_point_t.const_data_ptr<int32_t>()[0],
-src_zero_point,
out_multiplier.const_data_ptr<int32_t>()[0],
out_shift.const_data_ptr<int32_t>()[0],
out_zero_point);
} else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
// int8 path: same structure as the Byte branch, with int8_t data pointers
// for src / weight / out; bias is still read as int32_t.
const int8_t* __restrict__ in_data = src.const_data_ptr<int8_t>();
const int8_t* __restrict__ weight_data = weight.const_data_ptr<int8_t>();
const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();

// Same argument list as the uint8 xa_nn_matmul call above.
// NOTE(review): the return value is likewise not checked here.
xa_nn_matmul_asym8sxasym8s_asym8s(
out_data,
weight_data,
in_data,
bias_data,
out_dim,
in_dim,
in_dim,
leading_dims,
in_dim,
out_dim,
1,
-weight_zero_point_t.const_data_ptr<int32_t>()[0],
-src_zero_point,
out_multiplier.const_data_ptr<int32_t>()[0],
out_shift.const_data_ptr<int32_t>()[0],
out_zero_point);
} else {
// NOTE(review): the branches above dispatch on out.scalar_type(), but this
// message says "input dtype" and prints src.scalar_type() -- confirm which
// tensor's dtype should be reported.
ET_CHECK_MSG(
false,
"Unhandled input dtype %hhd",
static_cast<int8_t>(src.scalar_type()));
}
}

}; // namespace native
Expand Down