Commit ac33199

Change dot function to operator() overload, so a lambda function can be passed

1 parent 7e059ff · commit ac33199

7 files changed: +75 −46 lines changed

README.md

Lines changed: 36 additions & 36 deletions

````diff
--- a/README.md
+++ b/README.md
@@ -1,22 +1,22 @@
 # Lanczos solver,
-Computes the matrix-vector product sqrt(M)·v using a recursive algorithm.
-For that, it requires a functor with a "dot" function that takes an output real* array and an input real* (both in device memory if compiled in CUDA mode or host memory otherwise) as:
-```c++
-virtual void dot(real* in_v, real * out_Mv) override;
-```
-The functor must inherit ```lanczos::MatrixDot``` (see example.cu).
-This function must fill "out" with the result of performing the M·v dot product- > out = M·a_v.
-If M has size NxN and the cost of the dot product is O(M). The total cost of the algorithm is O(m·M). Where m << N.
-If M·v performs a dense M-V product, the cost of the algorithm would be O(m·N^2).
+Computes the matrix-vector product sqrt(M)·v using a recursive algorithm.
+For that, it requires a functor that takes an output real* array and an input real* (both in device memory if compiled in CUDA mode or host memory otherwise) as:
+```c++
+void operator()(real* in_v, real * out_Mv) override;
+```
+The functor can inherit ```lanczos::MatrixDot``` (see example.cu).
+This function must fill "out" with the result of performing the M·v dot product- > out = M·a_v.
+If M has size NxN and the cost of the dot product is O(M). The total cost of the algorithm is O(m·M). Where m << N.
+If M·v performs a dense M-V product, the cost of the algorithm would be O(m·N^2).
 
 This is a header-only library, although a shared library can be compiled instead.
 
-## Usage:
+## Usage:
 
-See example.cu for an usage example that can be compiled to work in GPU or CPU mode instinctively.
-See example.cpp for a CPU only example.
+See example.cu for an usage example that can be compiled to work in GPU or CPU mode instinctively.
+See example.cpp for a CPU only example.
 
-Let us go through the remaining one, a GPU-only example.
+Let us go through the remaining one, a GPU-only example.
 
 
 Create the module:
 ```c++
@@ -30,8 +30,8 @@ Write a functor that computes the product between the original matrix and a give
 struct DiagonalMatrix: public lanczos::MatrixDot{
   int size;
   DiagonalMatrix(int size): size(size){}
-
-  void dot(real* v, real* Mv) override{
+
+  void operator()(real* v, real* Mv) override{
     //An example diagonal matrix
     for(int i=0; i<size; i++){
       Mv[i] = 2*v[i];
@@ -42,7 +42,7 @@ struct DiagonalMatrix: public lanczos::MatrixDot{
 
 ```
 
-Provide the solver with an instance of the functor and the target vector:
+Provide the solver with an instance of the functor and the target vector:
 
 ```c++
 int size = 10;
@@ -55,7 +55,7 @@ Provide the solver with an instance of the functor and the target vector:
 DiagonalMatrix dot(size);
 //Call the solver
 real* d_result = thrust::raw_pointer_cast(result.data());
-real* d_v = thrust::raw_pointer_cast(v.data());
+real* d_v = thrust::raw_pointer_cast(v.data());
 real tolerance = 1e-6;
 int numberIterations = lanczos.run(dot, d_result, d_v, tolerance, size);
 int iterations = 100;
@@ -64,24 +64,24 @@ Provide the solver with an instance of the functor and the target vector:
 The run function returns the number of iterations that were needed to achieve the requested accuracy.
 The runIterations returns the residual after the requested iterations.
 
-## Other functions:
+## Other functions:
 
-After a certain number of iterations, if convergence was not achieved, the module will give up and throw an error.
-To increase this threshold you can use this function:
+After a certain number of iterations, if convergence was not achieved, the module will give up and throw an error.
+To increase this threshold you can use this function:
 ```c++
 lanczos::Solver::setIterationHardLimit(int newlimit);
 ```
-## Compilation:
-This library requires lapacke and cblas (can be replaced by MKL). In GPU mode CUDA is also needed.
-Note, however, that the heavy-weight of this solver comes from the Matrix-vector multiplication that must provided by the user. The main benefit of the CUDA mode is not an increased performance of the internal library code, but the fact that the input/output arrays will live in the GPU (saving potential memory copies).
-## Optional macros:
+## Compilation:
+This library requires lapacke and cblas (can be replaced by MKL). In GPU mode CUDA is also needed.
+Note, however, that the heavy-weight of this solver comes from the Matrix-vector multiplication that must provided by the user. The main benefit of the CUDA mode is not an increased performance of the internal library code, but the fact that the input/output arrays will live in the GPU (saving potential memory copies).
+## Optional macros:
 
-**CUDA_ENABLED**: Will compile a GPU enabled shared library, the solver expects input/output arrays to be in the GPU and most of the computations will be carried out in the GPU. Requires a working CUDA environment.
-**DOUBLE_PRECISION**: The library is compiled in single precision by default. This macro switches to double precision, making ```lanczos::real``` be a typedef to double.
-**USE_MKL**: Will include mkl.h instead of lapacke and cblas. You will have to modify the compilation flags accordingly.
-**SHARED_LIBRARY_COMPILATION**: The Makefile uses this macro to compile a shared library. By default, this library is header only.
+**CUDA_ENABLED**: Will compile a GPU enabled shared library, the solver expects input/output arrays to be in the GPU and most of the computations will be carried out in the GPU. Requires a working CUDA environment.
+**DOUBLE_PRECISION**: The library is compiled in single precision by default. This macro switches to double precision, making ```lanczos::real``` be a typedef to double.
+**USE_MKL**: Will include mkl.h instead of lapacke and cblas. You will have to modify the compilation flags accordingly.
+**SHARED_LIBRARY_COMPILATION**: The Makefile uses this macro to compile a shared library. By default, this library is header only.
 
-See the Makefile for further instructions.
+See the Makefile for further instructions.
 
 
 ## Python interface
@@ -91,13 +91,13 @@ See python/example.py for more information.
 The root folder's Makefile will try to compile the python library as well. It expects pybind11 to be placed under the extern/ folder. Pybind11 is included as a submodule, so make sure to clone this repository with --recursive.
 Note that the python wrapper can only be compiled in CPU mode.
 
-## References:
-
-[1] Krylov subspace methods for computing hydrodynamic interactions in Brownian dynamics simulations J. Chem. Phys. 137, 064106 (2012); doi: 10.1063/1.4742347
-
-## Some notes:
+## References:
+
+[1] Krylov subspace methods for computing hydrodynamic interactions in Brownian dynamics simulations J. Chem. Phys. 137, 064106 (2012); doi: 10.1063/1.4742347
+
+## Some notes:
 
-From what I have seen, this algorithm converges to an error of ~1e-3 in a few steps (<5) and from that point a lot of iterations are needed to lower the error.
-It usually achieves machine precision in under 50 iterations.
+From what I have seen, this algorithm converges to an error of ~1e-3 in a few steps (<5) and from that point a lot of iterations are needed to lower the error.
+It usually achieves machine precision in under 50 iterations.
 
-If the matrix does not have a sqrt (not positive definite, not symmetric...) it will usually be reflected as a nan in the current error estimation. In this case an exception will be thrown.
+If the matrix does not have a sqrt (not positive definite, not symmetric...) it will usually be reflected as a nan in the current error estimation. In this case an exception will be thrown.
````

example.cpp

Lines changed: 2 additions & 2 deletions

```diff
--- a/example.cpp
+++ b/example.cpp
@@ -22,8 +22,8 @@ using real = lanczos::real;
 struct DiagonalMatrix: public lanczos::MatrixDot{
   int size;
   DiagonalMatrix(int size): size(size){}
-
-  void dot(real* v, real* Mv) override{
+
+  void operator()(real* v, real* Mv) override{
     //an example diagonal matrix
     for(int i=0; i<size; i++){
       Mv[i] = (2+i/10.0)*v[i]*2;
```

example.cu

Lines changed: 1 addition & 1 deletion

```diff
--- a/example.cu
+++ b/example.cu
@@ -23,7 +23,7 @@ struct DiagonalMatrix: public lanczos::MatrixDot{
   int size;
   DiagonalMatrix(int size): size(size){}
 
-  void dot(real* v, real* Mv) override{
+  void operator()(real* v, real* Mv) override{
     //An example diagonal matrix
     for(int i=0; i<size; i++){
       Mv[i] = (2+i/10.0)*v[i];
```

include/LanczosAlgorithm.cu

Lines changed: 1 addition & 1 deletion

```diff
--- a/include/LanczosAlgorithm.cu
+++ b/include/LanczosAlgorithm.cu
@@ -147,7 +147,7 @@ namespace lanczos{
     auto d_w = detail::getRawPointer(w);
     /*w = D·vi*/
     dot->setSize(N);
-    dot->dot(d_V+N*i, d_w);
+    dot->operator()(d_V+N*i, d_w);
     if(i>0){
       /*w = w-h[i-1][i]·vi*/
       real alpha = -hsup[i-1];
```

include/LanczosAlgorithm.h

Lines changed: 10 additions & 1 deletion

```diff
--- a/include/LanczosAlgorithm.h
+++ b/include/LanczosAlgorithm.h
@@ -22,7 +22,8 @@ Some notes:
 #include"utils/defines.h"
 #include"utils/device_blas.h"
 #include<vector>
-#include<memory>
+#include <memory>
+#include<functional>
 #include"utils/device_container.h"
 #include"utils/MatrixDot.h"
 namespace lanczos{
@@ -39,13 +40,21 @@ namespace lanczos{
     int run(MatrixDot &dot, real *Bv, const real* v, real tolerance, int N){
       return run(&dot, Bv, v, tolerance, N);
     }
+    int run(std::function<void(real*, real*)> dot, real *Bv, const real* v, real tolerance, int N){
+      auto lanczos_dot = createMatrixDotAdaptor(dot);
+      return run(lanczos_dot, Bv, v, tolerance, N);
+    }
     //Given a Dotctor that computes a product M·v (where M is handled by Dotctor ), computes Bv = sqrt(M)·v
     //Returns the residual after numberIterations iterations
     //B = sqrt(M)
     real runIterations(MatrixDot *dot, real *Bz, const real*z, int numberIterations, int N);
     real runIterations(MatrixDot &dot, real *Bv, const real* v, int numberIterations, int N){
       return runIterations(&dot, Bv, v, numberIterations, N);
     }
+    real runIterations(std::function<void(real*, real*)> dot, real *Bv, const real* v, int numberIterations, int N){
+      auto lanczos_dot = createMatrixDotAdaptor(dot);
+      return runIterations(lanczos_dot, Bv, v, numberIterations, N);
+    }
 
     void setIterationHardLimit(int newLimit){this->iterationHardLimit = newLimit;}
 
```

include/utils/MatrixDot.h

Lines changed: 18 additions & 2 deletions

```diff
--- a/include/utils/MatrixDot.h
+++ b/include/utils/MatrixDot.h
@@ -1,13 +1,29 @@
 #ifndef LANCZOS_MATRIX_DOT_H
 #define LANCZOS_MATRIX_DOT_H
-#include"defines.h"
+#include "defines.h"
+#include<functional>
 namespace lanczos{
 
   struct MatrixDot{
     void setSize(int newsize){this->m_size = newsize;}
-    virtual void dot(real* v, real*Mv) = 0;
+    //virtual void dot(real* v, real*Mv) = 0;
+    virtual void operator()(real* v, real*Mv) = 0;
   protected:
     int m_size;
   };
+
+  //Transforms any callable into a MatrixDot valid to use with Lanczos
+  template<class Foo>
+  struct MatrixDotAdaptor: public lanczos::MatrixDot{
+    Foo& foo;
+    MatrixDotAdaptor(Foo &&foo):foo(foo){}
+    void operator()(real* v, real* Mv) override{foo(v,Mv);}
+  };
+
+  template<class Foo>
+  auto createMatrixDotAdaptor(Foo &&foo){
+    return MatrixDotAdaptor<Foo>(foo);
+  }
+
 }
 #endif
```

python/python_wrapper.cpp

Lines changed: 7 additions & 3 deletions

```diff
--- a/python/python_wrapper.cpp
+++ b/python/python_wrapper.cpp
@@ -15,8 +15,12 @@ using namespace pybind11::literals;
 //This class allows to inherit lanczos::MatrixDot from python
 struct MatrixDotTrampoline: public lanczos::MatrixDot{
   using MatrixDot::MatrixDot;
-
-  void dot(real* v, real* Mv) override{
+
+  void dot(real* v, real* Mv){
+    this->operator()(v, Mv);
+  }
+
+  void operator()(real* v, real* Mv) override{
     pybind11::gil_scoped_acquire gil; // Acquire the GIL while in this scope.
     // Try to look up the overridden method on the Python side.
     pybind11::function overridef = pybind11::get_override(this, "dot");
@@ -60,7 +64,7 @@ class PyLanczos{
 PYBIND11_MODULE(LANCZOS_PYTHON_NAME, m){
   py::class_<lanczos::MatrixDot, MatrixDotTrampoline>(m, "MatrixDot", "The virtual class required by the Lanczos solver").
     def(py::init<>()).
-    def("dot", &lanczos::MatrixDot::dot,
+    def("dot", &lanczos::MatrixDot::operator(),
       "Given a result (Mv) and a vector (v), this method must write in Mv the result of multiplying the target matrix and v.",
       "v"_a, "The input vector", "Mv"_a, "The output result vector");
```
