Skip to content

Commit 8f32013

Browse files
Mistobaanrmlarsen
authored andcommitted
add Rint operation (tensorflow#4113)
* add rint op * wip on custom rounding op * fix Makefile build for the rint op * rebase, add test, fix examples
1 parent d1060ca commit 8f32013

File tree

12 files changed

+202
-5
lines changed

12 files changed

+202
-5
lines changed

tensorflow/contrib/makefile/proto_text_cc_files.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ tensorflow/core/platform/posix/env.cc
1111
tensorflow/core/platform/posix/load_library.cc
1212
tensorflow/core/platform/file_system.cc
1313
tensorflow/core/platform/env.cc
14+
tensorflow/core/platform/setround.cc
1415
tensorflow/core/platform/denormal.cc
1516
tensorflow/core/platform/default/tracing.cc
1617
tensorflow/core/platform/default/logging.cc
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#if GOOGLE_CUDA
17+
18+
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
19+
20+
namespace tensorflow {
21+
namespace functor {
22+
DEFINE_UNARY2(rint, float, double);
23+
} // namespace functor
24+
} // namespace tensorflow
25+
26+
#endif // GOOGLE_CUDA
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/* Copyright 2016 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/kernels/cwise_ops_common.h"
17+
18+
namespace tensorflow {
19+
REGISTER2(UnaryOp, CPU, "Rint", functor::rint, float, double);
20+
#if GOOGLE_CUDA
21+
REGISTER2(UnaryOp, GPU, "Rint", functor::rint, float, double);
22+
#endif
23+
} // namespace tensorflow

tensorflow/core/kernels/cwise_ops.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,26 @@ struct round : base<T, Eigen::internal::scalar_round_op_google<T>> {};
521521
template <typename T>
522522
struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {};
523523

524+
/** this should go in Eigen
525+
* \brief Template functor to compute the round to int value of a scalar
526+
*/
527+
template<typename Scalar> struct scalar_rint_op {
528+
EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
529+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \
530+
const Scalar operator() (const Scalar& a) const {
531+
#if defined(__CUDACC__)
532+
return ::rint(a);
533+
#elif defined(PLATFORM_POSIX_ANDROID)
534+
return rint(a);
535+
#else
536+
return std::rint(a);
537+
#endif
538+
}
539+
};
540+
541+
template <typename T>
542+
struct rint : base<T, scalar_rint_op<T>> {};
543+
524544
////////////////////////////////////////////////////////////////////////////////
525545
// Binary functors
526546
////////////////////////////////////////////////////////////////////////////////

tensorflow/core/kernels/cwise_ops_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ BM_UNARY(gpu, Conj, std::complex<float>, DT_COMPLEX64);
5959
BM_UNARY(cpu, Conj, std::complex<double>, DT_COMPLEX128);
6060
BM_UNARY(gpu, Conj, std::complex<double>, DT_COMPLEX128);
6161

62+
BM_UNARY(cpu, Rint, double, DT_DOUBLE);
63+
BM_UNARY(gpu, Rint, double, DT_DOUBLE);
64+
BM_UNARY(cpu, Rint, float, DT_FLOAT);
65+
BM_UNARY(gpu, Rint, float, DT_FLOAT);
66+
6267
// data func scalar.
6368
static Graph* BinaryScalar(int num, const string& func) {
6469
Graph* g = new Graph(OpRegistry::Global());

tensorflow/core/lib/core/threadpool.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ limitations under the License.
2121
#include "tensorflow/core/platform/denormal.h"
2222
#include "tensorflow/core/platform/logging.h"
2323
#include "tensorflow/core/platform/mutex.h"
24+
#include "tensorflow/core/platform/setround.h"
2425
#include "tensorflow/core/platform/tracing.h"
2526
#include "tensorflow/core/platform/types.h"
2627

28+
2729
namespace tensorflow {
2830
namespace thread {
2931

@@ -50,6 +52,8 @@ struct EigenEnvironment {
5052
return env_->StartThread(thread_options_, name_, [=]() {
5153
// Set the processor flag to flush denormals to zero
5254
port::ScopedFlushDenormal flush;
55+
// Set the C++ rounding mode to ROUND TO NEAREST
56+
port::ScopedSetRound round;
5357
f();
5458
});
5559
}

tensorflow/core/ops/math_ops.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,25 @@ REGISTER_OP("Ceil")
454454
Returns element-wise smallest integer in not less than x.
455455
)doc");
456456

457+
REGISTER_OP("Rint")
458+
.Input("x: T")
459+
.Output("y: T")
460+
.Attr("T: {float, double}")
461+
.SetShapeFn(shape_inference::UnchangedShape)
462+
.Doc(R"doc(
463+
Returns element-wise integer closest to x.
464+
465+
If the result is midway between two representable values,
466+
the even representable is chosen.
467+
For example:
468+
469+
```
470+
rint(-1.5) ==> -2.0
471+
rint(0.5000001) ==> 1.0
472+
rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
473+
```
474+
)doc");
475+
457476
// Declares cwise binary operations signature: 't, 't -> 't.
458477

459478
#define BINARY_MORE() \

tensorflow/core/platform/setround.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/platform/setround.h"
17+
18+
#ifdef __STDC_IEC_559__
19+
#include <fenv.h> // fesetround, FE_*
20+
#endif
21+
22+
namespace tensorflow {
23+
namespace port {
24+
25+
ScopedSetRound::ScopedSetRound() {
26+
#ifdef __STDC_IEC_559__
27+
std::fesetround(FE_TONEAREST);
28+
#endif
29+
}
30+
31+
ScopedSetRound::~ScopedSetRound() {
32+
}
33+
34+
} // namespace port
35+
} // namespace tensorflow

tensorflow/core/platform/setround.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_PLATFORM_SETROUND_H_
17+
#define TENSORFLOW_PLATFORM_SETROUND_H_
18+
19+
#include "tensorflow/core/platform/macros.h"
20+
21+
namespace tensorflow {
22+
namespace port {
23+
24+
// While this class is active, floating point numbers are rounded to NEAREST
25+
// to zero. The destructor restores the original flags.
26+
class ScopedSetRound {
27+
public:
28+
ScopedSetRound();
29+
~ScopedSetRound();
30+
31+
private:
32+
TF_DISALLOW_COPY_AND_ASSIGN(ScopedSetRound);
33+
};
34+
35+
} // namespace port
36+
} // namespace tensorflow
37+
38+
#endif // TENSORFLOW_PLATFORM_SETROUN_H_

tensorflow/python/kernel_tests/cwise_ops_test.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1778,9 +1778,17 @@ def testSqrt(self):
17781778

17791779
class RoundingTest(tf.test.TestCase):
17801780

1781-
def _compare(self, x, use_gpu):
1781+
def _compare_values(self, x, y=None):
1782+
y = np.rint(x) if y is None else np.asarray(y)
1783+
with self.test_session() as sess:
1784+
tf_rint = tf.rint(x)
1785+
np_rint = sess.run(tf_rint)
1786+
self.assertAllEqual(y, np_rint)
1787+
self.assertShapeEqual(y, tf_rint)
1788+
1789+
def _compare(self, x):
17821790
np_floor, np_ceil = np.floor(x), np.ceil(x)
1783-
with self.test_session(use_gpu=use_gpu) as sess:
1791+
with self.test_session() as sess:
17841792
inx = tf.convert_to_tensor(x)
17851793
ofloor, oceil = tf.floor(inx), tf.ceil(inx)
17861794
tf_floor, tf_ceil = sess.run([ofloor, oceil])
@@ -1790,9 +1798,20 @@ def _compare(self, x, use_gpu):
17901798
self.assertShapeEqual(np_ceil, oceil)
17911799

17921800
def _testDtype(self, dtype):
1793-
data = (np.arange(-3, 3) / 4.).reshape([1, 3, 2]).astype(dtype)
1794-
self._compare(data, use_gpu=True)
1795-
self._compare(data, use_gpu=True)
1801+
data = (np.arange(-3, 3) / 4.).reshape(1, 3, 2).astype(dtype)
1802+
self._compare(data)
1803+
# TODO: rint op is not supported for float16
1804+
if dtype is np.float16:
1805+
return
1806+
self._compare_values(data)
1807+
x = [0.5, 0.5000001]
1808+
y = [0.0, 1.0]
1809+
self._compare_values(x, y=y)
1810+
1811+
# numpy example
1812+
x = [-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]
1813+
y = [-2., -2., -0., 0., 2., 2., 2.]
1814+
self._compare_values(x, y=y)
17961815

17971816
def testTypes(self):
17981817
for dtype in [np.float16, np.float32, np.float64]:

0 commit comments

Comments
 (0)