Skip to content

Commit

Permalink
Fix quantized_linear cpp op schema
Browse files Browse the repository at this point in the history
Summary: The cpp op schema does not match the registered one. Fix that.

Reviewed By: tarun292, cccclai

Differential Revision: D56594373

fbshipit-source-id: cb4853030715245e7a0177c0f193c4558f19584d
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Apr 26, 2024
1 parent 7b3f5c6 commit 44d4bac
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/cadence/ops/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
- arg_meta: null
kernel_name: impl::HiFi::quantized_layer_norm_out

- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
- func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::HiFi::quantized_linear_out
Expand Down
7 changes: 3 additions & 4 deletions examples/cadence/ops/quantized_linear_out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ void quantized_linear_out(
const Tensor& src,
const Tensor& weight,
const Tensor& bias,
double src_scale,
int64_t src_zero_point,
double weight_scale,
int64_t weight_zero_point,
const Tensor& weight_zero_point,
const Tensor& out_multiplier,
const Tensor& out_shift,
int64_t out_zero_point,
const exec_aten::optional<Tensor>& offset,
Tensor& out) {
// input comes in shape [leading_dims, in_dim]
// weight comes in shape [out_dim, in_dim]
Expand Down Expand Up @@ -58,7 +57,7 @@ void quantized_linear_out(
in_dim, // vec_offset of p_mat2.
out_dim, // out_offset, i.e., offset of next output element written
1, // out_stride, i.e., stride to go to next output row
-weight_zero_point, // mat1_zero_bias
-weight_zero_point.const_data_ptr<int32_t>()[0], // mat1_zero_bias
-src_zero_point, // mat2_zero_bias
out_multiplier.const_data_ptr<int32_t>(), // out_multiplier
out_shift.const_data_ptr<int32_t>(), // out_shift
Expand Down

0 comments on commit 44d4bac

Please sign in to comment.