Skip to content

Commit c0c6ffb

Browse files
authored
fix flatten for rank 0 (#75)
1 parent 94bb2b6 commit c0c6ffb

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

include/ttl/bits/std_tensor_reshape.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace internal
99
{
1010
template <typename R, typename S, typename D, typename A>
1111
struct flattener {
12-
using S1 = typename S::template subshape_t<S::rank - 1>;
12+
using S1 = basic_shape<1, typename S::dimension_type>;
1313
using vector =
1414
basic_tensor<R, S1, D, typename basic_tensor_traits<R, A, D>::Access>;
1515

tests/test_tensor_reshape.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,40 @@ TEST(tensor_reshape_test, test_chunk)
1616
static_assert(std::is_same<decltype(z), ttl::tensor_view<int, 1>>::value,
1717
"");
1818
}
19+
20+
template <typename R, ttl::rank_t r>
21+
void test_flatten(const ttl::tensor<R, r> &t)
22+
{
23+
{
24+
auto v = ttl::flatten(t);
25+
static_assert(std::is_same<decltype(v), ttl::vector_ref<R>>::value, "");
26+
ASSERT_EQ(v.shape().size(), t.shape().size());
27+
}
28+
{
29+
auto v = ttl::flatten(ttl::ref(t));
30+
static_assert(std::is_same<decltype(v), ttl::vector_ref<R>>::value, "");
31+
ASSERT_EQ(v.shape().size(), t.shape().size());
32+
}
33+
{
34+
auto v = ttl::flatten(ttl::view(t));
35+
static_assert(std::is_same<decltype(v), ttl::vector_view<R>>::value,
36+
"");
37+
ASSERT_EQ(v.shape().size(), t.shape().size());
38+
}
39+
}
40+
41+
TEST(tensor_reshape_test, test_flatten)
42+
{
43+
{
44+
ttl::tensor<int, 0> x;
45+
test_flatten(x);
46+
}
47+
{
48+
ttl::tensor<int, 1> x(4);
49+
test_flatten(x);
50+
}
51+
{
52+
ttl::tensor<int, 2> x(4, 5);
53+
test_flatten(x);
54+
}
55+
}

0 commit comments

Comments
 (0)