Skip to content

Commit 404de08

Browse files
Vladimir DimicVladimir Dimic
authored andcommitted
Update Functor-based matrices according to the recent AMF changes (#51)
Also, re-enable and improve outer unit test
1 parent 388ba38 commit 404de08

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed

include/alp/reference/matrix.hpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,20 @@ namespace alp {
881881
>::amf_type type;
882882
};
883883

884+
/** Specialization for containers that allocate storage */
885+
template< typename Structure, typename ImfC, enum Backend backend, typename Lambda >
886+
struct determine_amf_type< Structure, view::Functor< Lambda >, imf::Id, ImfC, backend > {
887+
888+
static_assert(
889+
std::is_same< ImfC, imf::Id >::value || std::is_same< ImfC, imf::Zero >::value,
890+
"Incompatible combination of parameters provided to determine_amf_type."
891+
);
892+
893+
typedef typename storage::AMFFactory::FromPolynomial<
894+
storage::polynomials::None_type
895+
>::amf_type type;
896+
};
897+
884898
} // namespace internal
885899

886900
/**
@@ -1117,7 +1131,7 @@ namespace alp {
11171131
> * = nullptr
11181132
>
11191133
Matrix( std::function< bool() > initialized, const size_t rows, const size_t cols, typename ViewType::applied_to lambda ) :
1120-
internal::FunctorBasedMatrix< T, ImfR, ImfC, typename View::applied_to >( initialized, rows, cols, lambda ) {}
1134+
internal::FunctorBasedMatrix< T, ImfR, ImfC, typename View::applied_to >( initialized, imf::Id( rows ), imf::Id( cols ), lambda ) {}
11211135

11221136
/**
11231137
* Constructor for a view over another functor-based matrix.
@@ -1321,8 +1335,8 @@ namespace alp {
13211335
internal::requires_allocation< ViewType >::value
13221336
> * = nullptr
13231337
>
1324-
Matrix( bool initialized, const size_t dim, typename ViewType::applied_to lambda ) :
1325-
internal::FunctorBasedMatrix< T, ImfR, ImfC, typename View::applied_to >( initialized, dim, dim, lambda ) {}
1338+
Matrix( std::function< bool() > initialized, const size_t dim, typename ViewType::applied_to lambda ) :
1339+
internal::FunctorBasedMatrix< T, ImfR, ImfC, typename View::applied_to >( initialized, imf::Id( dim ), imf::Id( dim ), lambda ) {}
13261340

13271341
/**
13281342
* Constructor for a view over another functor-based matrix.
@@ -1499,7 +1513,7 @@ namespace alp {
14991513
> * = nullptr
15001514
>
15011515
Matrix( std::function< bool() > initialized, const size_t dim, typename ViewType::applied_to lambda ) :
1502-
internal::FunctorBasedMatrix< T, ImfR, ImfC, typename View::applied_to >( initialized, dim, dim, lambda ) {}
1516+
internal::FunctorBasedMatrix< T, ImfR, ImfC, typename View::applied_to >( initialized, imf::Id( dim ), imf::Id( dim ), lambda ) {}
15031517

15041518
/**
15051519
* Constructor for a view over another functor-based matrix.
@@ -1691,8 +1705,8 @@ namespace alp {
16911705
internal::requires_allocation< ViewType >::value
16921706
> * = nullptr
16931707
>
1694-
Matrix( bool initialized, const size_t dim, typename ViewType::applied_to lambda ) :
1695-
internal::FunctorBasedMatrix< T, ImfR, ImfC, typename View::applied_to >( initialized, dim, lambda ) {}
1708+
Matrix( std::function< bool() > initialized, const size_t dim, typename ViewType::applied_to lambda ) :
1709+
internal::FunctorBasedMatrix< T, ImfR, ImfC, typename View::applied_to >( initialized, imf::Id( dim ), imf::Id( dim ), lambda ) {}
16961710

16971711
/**
16981712
* Constructor for a view over another functor-based matrix.

tests/unit/CMakeLists.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,9 @@ add_grb_executables( dense_mxm dense_mxm.cpp
242242
BACKENDS alp_reference
243243
)
244244

245-
# Temporarily disable until views on functor-based containers are fully functional
246-
#add_grb_executables( dense_outer dense_outer.cpp
247-
# BACKENDS alp_reference
248-
#)
245+
add_grb_executables( dense_outer dense_outer.cpp
246+
BACKENDS alp_reference
247+
)
249248

250249
add_grb_executables( dense_structured_matrix dense_structured_matrix.cpp
251250
BACKENDS alp_reference

tests/unit/dense_outer.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void outer_stdvec_as_matrix(
6262
const Operator oper
6363
) {
6464

65-
print_stdvec_as_matrix("vA", vA, n, 1, 1);
65+
print_stdvec_as_matrix("vA", vA, m, 1, 1);
6666
print_stdvec_as_matrix("vB", vB, 1, n, n);
6767
print_stdvec_as_matrix("vC - PRE", vC, m, n, n);
6868

@@ -129,13 +129,18 @@ void alpProgram( const size_t &n, alp::RC &rc ) {
129129

130130
alp::Semiring< alp::operators::add< T >, alp::operators::mul< T >, alp::identities::zero, alp::identities::one > ring;
131131

132-
T one = ring.getOne< T >();
133132
T zero = ring.getZero< T >();
134133

135134
// allocate
136135
const size_t m = 2 * n;
137-
std::vector< T > u_data( m, one );
138-
std::vector< T > v_data( n, one );
136+
std::vector< T > u_data( m );
137+
for( size_t i = 0; i < u_data.size(); ++i ) {
138+
u_data[ i ] = i + 1;
139+
}
140+
std::vector< T > v_data( n );
141+
for( size_t i = 0; i < v_data.size(); ++i ) {
142+
v_data[ i ] = i + 1;
143+
}
139144
std::vector< T > M_data( n, zero );
140145

141146
alp::Vector< T > u( m );
@@ -153,14 +158,16 @@ void alpProgram( const size_t &n, alp::RC &rc ) {
153158

154159
std::cout << "Is uvT initialized after initializing source containers? " << alp::internal::getInitialized( uvT ) << "\n";
155160

161+
print_matrix( "uvT", uvT );
162+
156163
std::vector< T > uvT_test( m * n, zero );
157164
outer_stdvec_as_matrix( uvT_test, n, u_data, v_data, m, n, ring.getMultiplicativeOperator() );
158165
diff_stdvec_matrix( uvT_test, m, n, n, uvT );
159166

160167
// Example when outer product takes the same vector as both inputs.
161168
// This operation results in a symmetric positive definite matrix.
162169
auto vvT = alp::outer( v, ring.getMultiplicativeOperator() );
163-
print_matrix( "vvT", uvT );
170+
print_matrix( "vvT", vvT );
164171

165172
std::vector< T > vvT_test( n * n, zero );
166173
outer_stdvec_as_matrix( vvT_test, n, v_data, v_data, n, n, ring.getMultiplicativeOperator() );

0 commit comments

Comments
 (0)