Skip to content

Commit 413c796

Browse files
authored
ttl/algorithm (#30)
* init ttl/algorithm * min/max * fix build
1 parent 4b483ac commit 413c796

File tree

3 files changed

+142
-0
lines changed

3 files changed

+142
-0
lines changed

include/ttl/algorithm

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// # -*- mode: c++ -*-
2+
#pragma once
3+
4+
#include <ttl/bits/std_tensor_algo.hpp>
5+
6+
namespace ttl
7+
{
8+
using ttl::internal::argmax;
9+
using ttl::internal::cast;
10+
using ttl::internal::fill;
11+
using ttl::internal::hamming_distance;
12+
using ttl::internal::max;
13+
using ttl::internal::min;
14+
using ttl::internal::sum;
15+
} // namespace ttl
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#pragma once
2+
#include <algorithm>
3+
#include <functional>
4+
5+
#include <ttl/bits/std_shape.hpp>
6+
#include <ttl/bits/std_tensor.hpp>
7+
8+
namespace ttl
9+
{
10+
namespace internal
11+
{
12+
template <typename R, typename S, typename D = typename S::dimension_type>
13+
D argmax(const basic_tensor_view<R, 1, S> &t)
14+
{
15+
return std::max_element(t.data(), t.data_end()) - t.data();
16+
}
17+
18+
template <typename R, typename R1, rank_t r, typename S>
19+
void cast(const basic_tensor_view<R, r, S> &x,
20+
const basic_tensor_ref<R1, r, S> &y)
21+
{
22+
std::transform(x.data(), x.data_end(), y.data(),
23+
[](const R &e) -> R1 { return static_cast<R1>(e); });
24+
}
25+
26+
template <typename R, rank_t r, typename S>
27+
void fill(const basic_tensor_ref<R, r, S> &t, const R &x)
28+
{
29+
std::fill(t.data(), t.data_end(), x);
30+
}
31+
32+
template <typename R, rank_t r, typename S,
33+
typename D = typename S::dimension_type>
34+
D hamming_distance(const basic_tensor_view<R, r, S> &x,
35+
const basic_tensor_view<R, r, S> &y)
36+
{
37+
return std::inner_product(x.data(), x.data_end(), y.data(),
38+
static_cast<D>(0), std::plus<D>(),
39+
std::not_equal_to<R>());
40+
}
41+
42+
template <typename R, rank_t r, typename S>
43+
R max(const basic_tensor_view<R, r, S> &t)
44+
{
45+
return *std::max_element(t.data(), t.data_end());
46+
}
47+
48+
template <typename R, rank_t r, typename S>
49+
R min(const basic_tensor_view<R, r, S> &t)
50+
{
51+
return *std::min_element(t.data(), t.data_end());
52+
}
53+
54+
template <typename R, rank_t r, typename S>
55+
R sum(const basic_tensor_view<R, r, S> &t)
56+
{
57+
return std::accumulate(t.data(), t.data_end(), static_cast<R>(0));
58+
}
59+
60+
} // namespace internal
61+
} // namespace ttl

tests/test_algo.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include "testing.hpp"
2+
3+
#include <ttl/algorithm>
4+
#include <ttl/tensor>
5+
6+
TEST(tensor_algo_test, test_argmax)
7+
{
8+
using R = float;
9+
ttl::tensor<R, 1> t(10);
10+
std::iota(t.data(), t.data_end(), 1);
11+
ASSERT_EQ(static_cast<uint32_t>(9), ttl::argmax(view(t)));
12+
}
13+
14+
TEST(tensor_algo_test, test_cast)
15+
{
16+
int n = 10;
17+
18+
ttl::tensor<float, 1> x(n);
19+
std::generate(x.data(), x.data_end(), [v = 0.1]() mutable {
20+
auto u = v;
21+
v += 0.2;
22+
return u;
23+
});
24+
25+
ttl::tensor<int, 1> y(n);
26+
ttl::cast(view(x), ref(y));
27+
28+
ASSERT_EQ(5, ttl::sum(view(y)));
29+
}
30+
31+
TEST(tensor_algo_test, test_fill)
32+
{
33+
{
34+
using R = int;
35+
ttl::tensor<R, 1> t(10);
36+
ttl::fill(ref(t), 1);
37+
}
38+
{
39+
using R = float;
40+
ttl::tensor<R, 1> t(10);
41+
ttl::fill(ref(t), static_cast<R>(1.1));
42+
}
43+
}
44+
45+
TEST(tensor_algo_test, test_hamming_distance)
46+
{
47+
using R = int;
48+
int n = 0xffff;
49+
ttl::tensor<R, 1> x(n);
50+
ttl::fill(ref(x), -1);
51+
ttl::tensor<R, 1> y(n);
52+
ttl::fill(ref(y), 1);
53+
ASSERT_EQ(static_cast<uint32_t>(n),
54+
ttl::hamming_distance(view(x), view(y)));
55+
}
56+
57+
TEST(tensor_algo_test, test_summaries)
58+
{
59+
using R = int;
60+
const int n = 10;
61+
ttl::tensor<R, 1> x(n);
62+
std::iota(x.data(), x.data_end(), -5);
63+
ASSERT_EQ(-5, ttl::min(view(x)));
64+
ASSERT_EQ(4, ttl::max(view(x)));
65+
ASSERT_EQ(-5, ttl::sum(view(x)));
66+
}

0 commit comments

Comments
 (0)