Skip to content

Commit cf18244

Browse files
authored
fix shift_idx (#97)
1 parent ef1a524 commit cf18244

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

include/ttl/bits/std_shape.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ N product(Iterator begin, Iterator end)
1717
};
1818

1919
template <size_t off, typename T, size_t r, size_t... Is>
20-
constexpr std::array<T, r - 1> shift_idx(const std::array<T, r> &a,
21-
std::index_sequence<Is...>)
20+
constexpr std::array<T, r - off> shift_idx(const std::array<T, r> &a,
21+
std::index_sequence<Is...>)
2222
{
23-
return std::array<T, r - 1>({std::get<Is + off>(a)...});
23+
return std::array<T, r - off>({std::get<Is + off>(a)...});
2424
}
2525

2626
using rank_t = uint8_t;

tests/test_shape.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ TEST(shape_test, test_flatten)
181181
ASSERT_EQ(s8, ttl::shape<1>(720));
182182
}
183183

184-
struct S {
184+
struct Test_constexpr {
185185
static constexpr auto shape = ttl::make_shape(2, 3);
186186
};
187187

@@ -195,3 +195,12 @@ TEST(shape_test, test_batch_vectorize)
195195
auto vs1 = ttl::vectorize(shape);
196196
ASSERT_EQ(vs1, ttl::make_shape(2, 3, 1));
197197
}
198+
199+
TEST(shape_test, test_subshape)
200+
{
201+
auto s = ttl::make_shape(1, 2, 3, 4);
202+
auto t1 = s.subshape(); // (2, 3, 4)
203+
ASSERT_EQ(t1.size(), 24);
204+
auto t2 = s.subshape<2>(); // (3, 4)
205+
ASSERT_EQ(t2.size(), 12);
206+
}

0 commit comments

Comments
 (0)