Skip to content

Commit

Permalink
Support negative axis for reverse_v2.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 301736084
Change-Id: Id23154b459566d5f85ef0c8d8bc98117c510b963
  • Loading branch information
renjie-liu authored and tensorflower-gardener committed Mar 19, 2020
1 parent ffbdfbb commit 6eca562
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion tensorflow/lite/kernels/reverse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
int axis = GetTensorData<int32_t>(axis_tensor)[0];
const int rank = NumDimensions(input);
if (axis < 0) {
axis += rank;
}

TF_LITE_ENSURE(context, axis >= 0 && axis < NumDimensions(input));
TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

switch (output->type) {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/testing/op_tests/reverse_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def make_reverse_v2_tests(options):
test_parameters = [{
"dtype": [tf.float32, tf.bool],
"base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
"axis": [0, 1, 2, 3],
"axis": [-2, -1, 0, 1, 2, 3],
}]

def get_valid_axis(parameters):
Expand Down

0 comments on commit 6eca562

Please sign in to comment.