Skip to content

Commit

Permalink
[QNN] Replace nn.leaky_relu with qnn.leaky_relu (apache#11930)
Browse files Browse the repository at this point in the history
* [QNN] Replace nn.leaky_relu with qnn.leaky_relu

* jostle ci

* fix typo
  • Loading branch information
zhaoyang-star authored Jul 12, 2022
1 parent 993a8ea commit 6d676ba
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,10 +937,9 @@ def _impl(inputs, _):
return _impl


def _leaky_relu():
def _leaky_relu(fp32_piggy_back=False):
# refer to src/ATen/native/quantized/cpu/qrelu.cpp
def _impl(inputs, _):
assert len(inputs) == 7, "Input quant params not found in op inputs"
def _impl_fp32(inputs, _):
alpha = inputs[1]
output_scale = _expr.const(inputs[3])
output_zero_point = _expr.const(inputs[4])
Expand All @@ -952,6 +951,18 @@ def _impl(inputs, _):
dequantized, output_scale, output_zero_point, out_dtype="uint8"
)

def _impl_int8(inputs, _):
alpha = inputs[1]
output_scale = _expr.const(inputs[3])
output_zero_point = _expr.const(inputs[4])
return relay.qnn.op.leaky_relu(inputs[0], alpha, output_scale, output_zero_point)

def _impl(inputs, _):
assert len(inputs) == 7, "Input quant params not found in op inputs"
if fp32_piggy_back:
return _impl_fp32(inputs, _)
return _impl_int8(inputs, _)

return _impl


Expand Down

0 comments on commit 6d676ba

Please sign in to comment.