Skip to content

Commit f615378

Browse files
committed
Minor refactor in median. Add one and two element tests
1 parent f211253 commit f615378

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

src/api/c/median.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,12 @@ static double median(const af_array& in) {
4141
} else if (nElems == 2) {
4242
T result[2];
4343
AF_CHECK(af_get_data_ptr((void*)&result, in));
44-
if (input.isFloating()) {
45-
return division(result[0] + result[1], 2.0);
46-
} else {
47-
return division((float)result[0] + (float)result[1], 2.0);
48-
}
44+
return division(
45+
(static_cast<double>(result[0]) + static_cast<double>(result[1])),
46+
2.0);
4947
}
5048

51-
double mid = (nElems + 1) / 2;
49+
double mid = static_cast<double>(nElems + 1) / 2.0;
5250
af_seq mdSpan[1] = {af_make_seq(mid - 1, mid, 1)};
5351

5452
Array<T> sortedArr = sort<T>(input, 0, true);
@@ -68,11 +66,9 @@ static double median(const af_array& in) {
6866
if (nElems % 2 == 1) {
6967
result = resPtr[0];
7068
} else {
71-
if (input.isFloating()) {
72-
result = division(resPtr[0] + resPtr[1], 2);
73-
} else {
74-
result = division((float)resPtr[0] + (float)resPtr[1], 2);
75-
}
69+
result = division(
70+
static_cast<double>(resPtr[0]) + static_cast<double>(resPtr[1]),
71+
2.0);
7672
}
7773

7874
return result;
@@ -90,9 +86,9 @@ static af_array median(const af_array& in, const dim_t dim) {
9086

9187
Array<T> sortedIn = sort<T>(input, dim, true);
9288

93-
int dimLength = input.dims()[dim];
94-
double mid = (dimLength + 1) / 2;
95-
af_array left = 0;
89+
size_t dimLength = input.dims()[dim];
90+
double mid = static_cast<double>(dimLength + 1) / 2.0;
91+
af_array left = 0;
9692

9793
af_seq slices[4] = {af_span, af_span, af_span, af_span};
9894
slices[dim] = af_make_seq(mid - 1.0, mid - 1.0, 1.0);

test/median.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,18 @@ MEDIAN(float, uchar)
150150
MEDIAN(float, short)
151151
MEDIAN(float, ushort)
152152
MEDIAN(double, double)
153+
154+
TEST(Median, OneElement) {
155+
af::array in = randu(1, f32);
156+
157+
af::array out = median(in);
158+
ASSERT_ARRAYS_EQ(in, out);
159+
}
160+
161+
TEST(Median, TwoElements) {
162+
af::array in = randu(2, f32);
163+
164+
af::array out = median(in);
165+
af::array gold = mean(in);
166+
ASSERT_ARRAYS_EQ(gold, out);
167+
}

0 commit comments

Comments
 (0)