-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX FE] Support lax.argmax operation for JAX (#26671)
### Details: - Support lax.argmax for JAX and create relevant layer test - 2 util improvements - Fix `num_inputs_check` not checking max inputs - Better error message when param name not exist ### Tickets: - #26574 --------- Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
- Loading branch information
1 parent
c5025cc
commit 9a02e54
Showing
5 changed files
with
108 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// Copyright (C) 2018-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "openvino/frontend/jax/node_context.hpp" | ||
#include "openvino/op/constant.hpp" | ||
#include "openvino/op/squeeze.hpp" | ||
#include "openvino/op/topk.hpp" | ||
#include "utils.hpp" | ||
|
||
namespace ov { | ||
namespace frontend { | ||
namespace jax { | ||
namespace op { | ||
|
||
using namespace ov::op; | ||
|
||
OutputVector translate_argmax(const NodeContext& context) { | ||
num_inputs_check(context, 1, 1); | ||
Output<Node> input = context.get_input(0); | ||
auto axis_val = context.const_named_param<int64_t>("axes"); | ||
auto axis = context.const_named_param<std::shared_ptr<v0::Constant>>("axes"); | ||
auto dtype = convert_dtype(context.const_named_param<int64_t>("index_dtype")); | ||
|
||
auto k = std::make_shared<v0::Constant>(element::i64, Shape{}, 1); | ||
auto topk = std::make_shared<v11::TopK>(input, | ||
k, | ||
axis_val, | ||
v11::TopK::Mode::MAX, | ||
v1::TopK::SortType::SORT_VALUES, | ||
dtype, | ||
true); | ||
auto indices = topk->output(1); | ||
|
||
auto res = std::make_shared<v0::Squeeze>(indices, axis); | ||
return {res}; | ||
}; | ||
|
||
} // namespace op | ||
} // namespace jax | ||
} // namespace frontend | ||
} // namespace ov |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (C) 2018-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import numpy as np | ||
import pytest | ||
from jax import lax | ||
from jax import numpy as jnp | ||
|
||
from jax_layer_test_class import JaxLayerTest | ||
|
||
rng = np.random.default_rng(706670) | ||
|
||
|
||
class TestArgmax(JaxLayerTest): | ||
def _prepare_input(self): | ||
if np.issubdtype(self.input_type, np.floating): | ||
x = rng.uniform(-5.0, 5.0, | ||
self.input_shape).astype(self.input_type) | ||
elif np.issubdtype(self.input_type, np.signedinteger): | ||
x = rng.integers(-8, 8, self.input_shape).astype(self.input_type) | ||
else: | ||
x = rng.integers(0, 8, self.input_shape).astype(self.input_type) | ||
|
||
if self.input_duplicate: | ||
x = np.concatenate((x, x), axis=self.axis) | ||
|
||
x = jnp.array(x) | ||
return [x] | ||
|
||
def create_model(self, input_shape, axis, input_type, index_dtype, input_duplicate): | ||
self.input_shape = input_shape | ||
self.axis = axis | ||
self.input_type = input_type | ||
self.input_duplicate = input_duplicate | ||
|
||
def jax_argmax(inp): | ||
out = lax.argmax(inp, axis, index_dtype) | ||
return out | ||
|
||
return jax_argmax, None, 'argmax' | ||
|
||
# Only [0, rank - 1] are valid axes for lax.argmax | ||
@pytest.mark.parametrize('input_shape, axis', [([64], 0), | ||
([64, 16], 0), | ||
([64, 16], 1), | ||
([48, 23, 54], 0), | ||
([48, 23, 54], 1), | ||
([48, 23, 54], 2), | ||
([2, 18, 32, 25], 0), | ||
([2, 18, 32, 25], 1), | ||
([2, 18, 32, 25], 2), | ||
([2, 18, 32, 25], 3)]) | ||
@pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16, | ||
np.int32, np.uint32, np.int64, np.uint64, | ||
np.float16, np.float32, np.float64]) | ||
@pytest.mark.parametrize("index_dtype", [np.int32, np.int64]) | ||
@pytest.mark.parametrize("input_duplicate", [False, True]) | ||
@pytest.mark.nightly | ||
@pytest.mark.precommit_jax_fe | ||
def test_argmax(self, ie_device, precision, ir_version, input_shape, axis, input_type, index_dtype, input_duplicate): | ||
self._test(*self.create_model(input_shape, axis, input_type, index_dtype, input_duplicate), | ||
ie_device, precision, ir_version) |