[JAX FE] Support lax.argmax operation for JAX (#26671)
### Details:
 - Support `lax.argmax` for JAX and add a corresponding layer test (see the short example below)
 - Two utility improvements:
   - Fix `num_inputs_check` not checking the maximum number of inputs
   - Better error message when a param name does not exist
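
For reference, `lax.argmax(x, axis, index_dtype)` returns the index of the (first) maximum along `axis`, with the requested index dtype. A minimal JAX-only sketch of the semantics being converted (example values are illustrative):

```python
import numpy as np
from jax import lax, numpy as jnp

x = jnp.array([[1.0, 5.0, 3.0],
               [7.0, 2.0, 2.0]])
# Index of the maximum along axis 1, returned with the requested dtype.
print(lax.argmax(x, 1, np.int32))  # [1 0]
```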

### Tickets:
 - #26574

---------

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
halm-zenger and rkazants authored Oct 20, 2024
1 parent c5025cc commit 9a02e54
Showing 5 changed files with 108 additions and 0 deletions.
@@ -101,6 +101,7 @@ class NodeContext : public frontend::NodeContext {
    }

    Output<Node> get_param(const std::string& name) const {
        FRONT_END_GENERAL_CHECK(m_param_name_to_id.count(name), "No param id corresponding to name exists: ", name);
        auto id = m_param_name_to_id.at(name);
        FRONT_END_GENERAL_CHECK(m_tensor_map->count(id), "No tensor corresponding param id: ", id, " exists.");
        return m_tensor_map->at(id);
42 changes: 42 additions & 0 deletions src/frontends/jax/src/op/argmax.cpp
@@ -0,0 +1,42 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/jax/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/topk.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace jax {
namespace op {

using namespace ov::op;

OutputVector translate_argmax(const NodeContext& context) {
    num_inputs_check(context, 1, 1);
    Output<Node> input = context.get_input(0);
    // "axes" holds the single reduction axis: read it both as a scalar
    // (for TopK) and as a Constant node (for Squeeze).
    auto axis_val = context.const_named_param<int64_t>("axes");
    auto axis = context.const_named_param<std::shared_ptr<v0::Constant>>("axes");
    auto dtype = convert_dtype(context.const_named_param<int64_t>("index_dtype"));

    // Top-1 along the axis; stable TopK (last argument) keeps the first
    // occurrence of the maximum, matching lax.argmax tie-breaking.
    auto k = std::make_shared<v0::Constant>(element::i64, Shape{}, 1);
    auto topk = std::make_shared<v11::TopK>(input,
                                            k,
                                            axis_val,
                                            v11::TopK::Mode::MAX,
                                            v11::TopK::SortType::SORT_VALUES,
                                            dtype,
                                            true);
    auto indices = topk->output(1);

    // Drop the size-1 axis left by TopK so the output has the argmax shape.
    auto res = std::make_shared<v0::Squeeze>(indices, axis);
    return {res};
}

} // namespace op
} // namespace jax
} // namespace frontend
} // namespace ov
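
As a cross-check, the TopK-plus-Squeeze decomposition used in `translate_argmax` can be sketched on the JAX side (a sanity sketch only; it uses `lax.top_k`, which operates on the last axis, so the reduction axis is moved there and back):

```python
import numpy as np
from jax import lax, numpy as jnp

x = jnp.array(np.random.default_rng(0).uniform(-5.0, 5.0, (4, 5)).astype(np.float32))
axis = 1

ref = lax.argmax(x, axis, np.int32)

# Emulate TopK(k=1) + Squeeze: take the top-1 indices along `axis`,
# then drop the size-1 dimension that is left behind.
_, idx = lax.top_k(jnp.moveaxis(x, axis, -1), 1)
emulated = jnp.squeeze(jnp.moveaxis(idx, -1, axis), axis=axis).astype(np.int32)

assert jnp.array_equal(ref, emulated)
```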
2 changes: 2 additions & 0 deletions src/frontends/jax/src/op_table.cpp
@@ -36,6 +36,7 @@ namespace op {
    template <class T> \
    OutputVector op(const ov::frontend::jax::NodeContext& node)

OP_CONVERTER(translate_argmax);
OP_T_CONVERTER(translate_binary_op);
OP_CONVERTER(translate_broadcast_in_dim);
OP_CONVERTER(translate_concatenate);
@@ -59,6 +60,7 @@ OP_CONVERTER(translate_transpose);
// Supported ops for Jaxpr
const std::map<std::string, CreatorFunction> get_supported_ops_jaxpr() {
    return {{"add", op::translate_1to1_match_2_inputs<v1::Add>},
            {"argmax", op::translate_argmax},
            {"broadcast_in_dim", op::translate_broadcast_in_dim},
            {"concatenate", op::translate_concatenate},
            {"constant", op::translate_constant},
1 change: 1 addition & 0 deletions src/frontends/jax/src/utils.cpp
@@ -16,6 +16,7 @@ namespace jax {
void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
    auto inputs = context.inputs();
    FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= min_inputs, "Got fewer inputs than expected");
    FRONT_END_OP_CONVERSION_CHECK(inputs.size() <= max_inputs, "Got more inputs than expected");
}

void num_inputs_check(const NodeContext& context, size_t min_inputs) {
62 changes: 62 additions & 0 deletions tests/layer_tests/jax_tests/test_argmax.py
@@ -0,0 +1,62 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
from jax import lax
from jax import numpy as jnp

from jax_layer_test_class import JaxLayerTest

rng = np.random.default_rng(706670)


class TestArgmax(JaxLayerTest):
    def _prepare_input(self):
        if np.issubdtype(self.input_type, np.floating):
            x = rng.uniform(-5.0, 5.0,
                            self.input_shape).astype(self.input_type)
        elif np.issubdtype(self.input_type, np.signedinteger):
            x = rng.integers(-8, 8, self.input_shape).astype(self.input_type)
        else:
            x = rng.integers(0, 8, self.input_shape).astype(self.input_type)

        if self.input_duplicate:
            x = np.concatenate((x, x), axis=self.axis)

        x = jnp.array(x)
        return [x]

    def create_model(self, input_shape, axis, input_type, index_dtype, input_duplicate):
        self.input_shape = input_shape
        self.axis = axis
        self.input_type = input_type
        self.input_duplicate = input_duplicate

        def jax_argmax(inp):
            out = lax.argmax(inp, axis, index_dtype)
            return out

        return jax_argmax, None, 'argmax'

    # Only [0, rank - 1] are valid axes for lax.argmax
    @pytest.mark.parametrize('input_shape, axis', [([64], 0),
                                                   ([64, 16], 0),
                                                   ([64, 16], 1),
                                                   ([48, 23, 54], 0),
                                                   ([48, 23, 54], 1),
                                                   ([48, 23, 54], 2),
                                                   ([2, 18, 32, 25], 0),
                                                   ([2, 18, 32, 25], 1),
                                                   ([2, 18, 32, 25], 2),
                                                   ([2, 18, 32, 25], 3)])
    @pytest.mark.parametrize('input_type', [np.int8, np.uint8, np.int16, np.uint16,
                                            np.int32, np.uint32, np.int64, np.uint64,
                                            np.float16, np.float32, np.float64])
    @pytest.mark.parametrize("index_dtype", [np.int32, np.int64])
    @pytest.mark.parametrize("input_duplicate", [False, True])
    @pytest.mark.nightly
    @pytest.mark.precommit_jax_fe
    def test_argmax(self, ie_device, precision, ir_version, input_shape, axis, input_type, index_dtype, input_duplicate):
        self._test(*self.create_model(input_shape, axis, input_type, index_dtype, input_duplicate),
                   ie_device, precision, ir_version)
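
The `input_duplicate` parametrization concatenates the input with itself along the reduction axis, so the maximum value is guaranteed to appear more than once; `lax.argmax` resolves such ties by returning the first occurrence, which the stable TopK in the converter must reproduce. A small illustration:

```python
import numpy as np
from jax import lax, numpy as jnp

x = jnp.array([3, 1, 3])            # duplicated maximum
print(lax.argmax(x, 0, np.int32))   # 0 -- the first occurrence wins
```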
