Skip to content

Commit

Permalink
Merge pull request #349 from MeasureTransport/mparno/mset-fixes
Browse files Browse the repository at this point in the history
Fixed bug in `FixedMultiindexSet` python bindings
  • Loading branch information
mparno authored Oct 5, 2023
2 parents aa08ee2 + 6ae0481 commit a9d3cd4
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 18 deletions.
10 changes: 5 additions & 5 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ cff-version: 1.2.0
title: Monotone Parameterization Toolkit (MParT)
message: 'If you use MParT, please cite it as below.'
type: software
version: 1.4.0
version: 2.1.1
authors:
- given-names: Matthew
family-names: Parno
email: parnomd@gmail.com
orcid: https://orcid.org/0000-0002-9419-2693
- given-names: Paul-Baptiste
family-names: Rubio
email: rubiop@mit.edu
orcid: https://orcid.org/0000-0002-9765-1162
- given-names: Daniel
family-names: Sharp
email: dannys4@vt.edu
orcid: https://orcid.org/0000-0002-0439-5084
- given-names: Paul-Baptiste
family-names: Rubio
email: rubiop@mit.edu
orcid: https://orcid.org/0000-0002-9765-1162
- given-names: Michael
family-names: Brennan
email: mcbrenn@mit.edu
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.13)
project(MParT VERSION 2.0.2)
project(MParT VERSION 2.1.1)

message(STATUS "Will install MParT to ${CMAKE_INSTALL_PREFIX}")

Expand Down
2 changes: 1 addition & 1 deletion MParT/TrainMapAdaptive.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct ATMOptions: public MapOptions, public TrainOptions {
/** Maximum number of iterations that do not improve error */
unsigned int maxPatience = 10;
/** Maximum number of coefficients in final expansion (including ALL dimensions of map) */
unsigned int maxSize = std::numeric_limits<unsigned int>::infinity();
unsigned int maxSize = std::numeric_limits<int>::max(); // <- use this instead of infinity because python doesn't have infinite ints
/** Multiindex representing the maximum degree in each input dimension */
MultiIndex maxDegrees;

Expand Down
2 changes: 1 addition & 1 deletion MParT/Utilities/ArrayConversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ namespace mpart{
template<typename ScalarType, class MemorySpace>
StridedVector<ScalarType, MemorySpace> ConstVecToKokkos(const std::vector<ScalarType> &vec)
{
double* ptr = const_cast<double*>(vec.data());
ScalarType* ptr = const_cast<ScalarType*>(vec.data());
return Kokkos::View<ScalarType*, MemorySpace>(ptr, vec.size());
}

Expand Down
28 changes: 20 additions & 8 deletions bindings/python/src/MultiIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,23 @@ void mpart::binding::MultiIndexWrapper(py::module &m)
}))

.def(py::init( [](unsigned int dim,
Eigen::Matrix<unsigned int, Eigen::Dynamic, 1> &nzStarts,
Eigen::Matrix<unsigned int, Eigen::Dynamic, 1> &nzDims,
Eigen::Matrix<unsigned int, Eigen::Dynamic, 1> &nzOrders)
{
return new FixedMultiIndexSet<Kokkos::HostSpace>(dim,
VecToKokkos<unsigned int, Kokkos::HostSpace>(nzStarts),
VecToKokkos<unsigned int, Kokkos::HostSpace>(nzDims),
VecToKokkos<unsigned int, Kokkos::HostSpace>(nzOrders));
Eigen::Matrix<unsigned int, Eigen::Dynamic, 1> nzStartsIn,
Eigen::Matrix<unsigned int, Eigen::Dynamic, 1> nzDimsIn,
Eigen::Matrix<unsigned int, Eigen::Dynamic, 1> nzOrdersIn)
{
// Deep copy the arrays into Kokkos
Kokkos::View<unsigned int*,Kokkos::HostSpace> nzStarts("nzStarts", nzStartsIn.rows());
Kokkos::View<unsigned int*,Kokkos::HostSpace> nzDims("nzDims", nzDimsIn.rows());
Kokkos::View<unsigned int*,Kokkos::HostSpace> nzOrders("nzOrders", nzOrdersIn.rows());

for(unsigned int i=0; i<nzStartsIn.rows(); ++i)
nzStarts(i) = nzStartsIn(i);
for(unsigned int i=0; i<nzDimsIn.rows(); ++i)
nzDims(i) = nzDimsIn(i);
for(unsigned int i=0; i<nzOrdersIn.rows(); ++i)
nzOrders(i) = nzOrdersIn(i);

return new FixedMultiIndexSet<Kokkos::HostSpace>(dim, nzStarts, nzDims, nzOrders);
}))

.def(py::init<unsigned int, unsigned int>())
Expand All @@ -189,6 +198,9 @@ void mpart::binding::MultiIndexWrapper(py::module &m)
.def("__len__", &FixedMultiIndexSet<Kokkos::HostSpace>::Length)
.def("Length", &FixedMultiIndexSet<Kokkos::HostSpace>::Length)
.def("Size", &FixedMultiIndexSet<Kokkos::HostSpace>::Size)
.def("IndexToMulti", &FixedMultiIndexSet<Kokkos::HostSpace>::IndexToMulti)
.def("MultiToIndex", &FixedMultiIndexSet<Kokkos::HostSpace>::MultiToIndex)

#if defined(MPART_HAS_CEREAL)
.def("Serialize", [](FixedMultiIndexSet<Kokkos::HostSpace> const &mset, std::string const &filename){
std::ofstream os(filename);
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ license={file="LICENSE.txt"}
readme="README.md"
requires-python = ">=3.7"
description="A Monotone Parameterization Toolkit"
version="2.0.2"
version="2.1.1"
keywords=["Measure Transport", "Monotone", "Transport Map", "Isotonic Regression", "Triangular", "Knothe-Rosenblatt"]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ def get_install_locations():
package_dir={'mpart': 'bindings/python/package'},
package_data={'mpart':['**/*pympart*']},
include_package_data=True,
cmake_args=['-DKokkos_ENABLE_THREADS:BOOL=ON', f'-DSKBUILD_LIB_RPATH={lib_folder}', f'-DSKBUILD_SITE_PATH={site_folder}', '-DPYTHON_INSTALL_SUFFIX=bindings/python/package/', '-DMPART_JULIA:BOOL=OFF', '-DMPART_MATLAB:BOOL=OFF', '-DMPART_BUILD_TESTS:BOOL=OFF', '-DMPART_PYTHON:BOOL=ON', '-DPYTHON_INSTALL_PREFIX=']
cmake_args=['-DKokkos_ENABLE_THREADS:BOOL=ON', '-DKokkos_ENABLE_THREADS=ON', f'-DSKBUILD_LIB_RPATH={lib_folder}', f'-DSKBUILD_SITE_PATH={site_folder}', '-DPYTHON_INSTALL_SUFFIX=bindings/python/package/', '-DMPART_JULIA:BOOL=OFF', '-DMPART_MATLAB:BOOL=OFF', '-DMPART_BUILD_TESTS:BOOL=OFF', '-DMPART_PYTHON:BOOL=ON', '-DPYTHON_INSTALL_PREFIX=']
)

0 comments on commit a9d3cd4

Please sign in to comment.