Skip to content

Commit

Permalink
Update bindings for CreateSigmoid.
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Parno committed May 11, 2024
1 parent 9eb557b commit 75bfc3d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 1 deletion.
5 changes: 5 additions & 0 deletions bindings/julia/src/MapFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ void mpart::binding::MapFactoryWrapper(jlcxx::Module &mod) {
mod.method("CreateTriangular", &MapFactory::CreateTriangular<MemorySpace>);

// CreateSigmoidComponent
mod.method("CreateSigmoidComponent", [](unsigned int inDim, unsigned int offDiagOrder, unsigned int crossOrder, jlcxx::ArrayRef<double,1> centers, MapOptions opts){
StridedVector<const double, MemorySpace> centersVec = JuliaToKokkos(centers);
return MapFactory::CreateSigmoidComponent<Kokkos::HostSpace>(inDim, offDiagOrder, crossOrder, centersVec, opts);
});

mod.method("CreateSigmoidComponent", [](unsigned int inDim, unsigned int totalOrder, jlcxx::ArrayRef<double,1> centers, MapOptions opts){
StridedVector<const double, MemorySpace> centersVec = JuliaToKokkos(centers);
return MapFactory::CreateSigmoidComponent<Kokkos::HostSpace>(inDim, totalOrder, centersVec, opts);
Expand Down
6 changes: 6 additions & 0 deletions bindings/julia/src/MapOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ void mpart::binding::MapOptionsWrapper(jlcxx::Module &mod) {
mod.add_bits<SigmoidTypes>("__SigmoidTypes", jlcxx::julia_type("CppEnum"));
mod.set_const("__Logistic", SigmoidTypes::Logistic);

// SigmoidSumTypes
mod.add_bits<SigmoidSumSizeType>("__SigmoidSumSizeType", jlcxx::julia_type("CppEnum"));
mod.set_const("__Linear", SigmoidSumSizeType::Linear);
mod.set_const("__Constant", SigmoidSumSizeType::Constant);

// EdgeTypes: TODO: SoftPlus overlaps with PosFuncTypes, needs to be fixed
// mod.add_bits<EdgeTypes>("__EdgeTypes", jlcxx::julia_type("CppEnum"));
// mod.set_const("__SoftPlus", EdgeTypes::SoftPlus);
Expand All @@ -45,6 +50,7 @@ void mpart::binding::MapOptionsWrapper(jlcxx::Module &mod) {
.method("__posFuncType!", [](MapOptions &opts, unsigned int f){ opts.posFuncType = static_cast<PosFuncTypes>(f); })
.method("__quadType!", [](MapOptions &opts, unsigned int quad){ opts.quadType = static_cast<QuadTypes>(quad); })
.method("__sigmoidType!", [](MapOptions &opts, unsigned int sig){ opts.sigmoidType = static_cast<SigmoidTypes>(sig); })
.method("__sigmoidBasisSumType!", [](MapOptions &opts, unsigned int sig){ opts.sigmoidBasisSumType = static_cast<SigmoidSumSizeType>(sig); })
.method("__edgeType!", [](MapOptions &opts, unsigned int edge){ opts.edgeType = static_cast<EdgeTypes>(edge); })
.method("__edgeShape!", [](MapOptions &opts, double width){ opts.edgeShape = width; })
.method("__quadAbsTol!", [](MapOptions &opts, double tol){ opts.quadAbsTol = tol; })
Expand Down
10 changes: 9 additions & 1 deletion bindings/matlab/mat/MapOptions/MapOptions.m
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
properties (Access = public)
basisType = BasisTypes.ProbabilistHermite;
sigmoidType = SigmoidTypes.Logistic;
sigmoidBasisSumType = SigmoidSumSizeType.Linear;
edgeType = EdgeTypes.SoftPlus;
posFuncType = PosFuncTypes.SoftPlus;
quadType = QuadTypes.AdaptiveSimpson;
Expand Down Expand Up @@ -37,6 +38,9 @@
function obj = set.sigmoidType(obj,type)
obj.sigmoidType = type;
end
function obj = set.sigmoidBasisSumType(obj,type)
obj.sigmoidBasisSumType = type;
end
function obj = set.edgeShape(obj,value)
obj.edgeShape = value;
end
Expand Down Expand Up @@ -84,6 +88,7 @@
optionsArray{14} = obj.basisUB;
optionsArray{15} = obj.basisNorm;
optionsArray{16} = obj.nugget;
optionsArray{17} = obj.sigmoidBasisSumType;
end

function res = eq(obj1, obj2)
Expand All @@ -94,6 +99,7 @@
res = res && isequal(obj1.basisNorm, obj2.basisNorm);
res = res && isequal(obj1.posFuncType, obj2.posFuncType);
res = res && isequal(obj1.sigmoidType, obj2.sigmoidType);
res = res && isequal(obj1.sigmoidBasisSumType, obj2.sigmoidBasisSumType);
res = res && isequal(obj1.edgeType, obj2.edgeType);
res = res && isequal(obj1.edgeShape, obj2.edgeShape);
res = res && isequal(obj1.quadType, obj2.quadType);
Expand All @@ -110,7 +116,8 @@ function Serialize(obj,filename)
MParT_('MapOptions_Serialize',filename, char(obj.basisType), ...
char(obj.sigmoidType), char(obj.edgeType), char(obj.posFuncType), char(obj.quadType), ...
obj.quadAbsTol, obj.quadRelTol, obj.quadMaxSub, obj.quadMinSub, obj.edgeShape, ...
obj.quadPts, obj.contDeriv, obj.basisLB, obj.basisUB, obj.basisNorm, obj.nugget)
obj.quadPts, obj.contDeriv, obj.basisLB, obj.basisUB, obj.basisNorm, obj.nugget,
char(obj.sigmoidBasisSumType))
end

function obj = Deserialize(obj,filename)
Expand Down Expand Up @@ -143,6 +150,7 @@ function Serialize(obj,filename)
obj.basisUB = optionsArray{14};
obj.basisNorm = optionsArray{15};
obj.nugget = optionsArray{16};
obj.sigmoidBasisSumType = optionsArray{17};
end
end

Expand Down
6 changes: 6 additions & 0 deletions bindings/matlab/mat/MapOptions/SigmoidSumSizeType.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
classdef SigmoidSumSizeType
enumeration
Linear
Constant
end
end
1 change: 1 addition & 0 deletions bindings/python/src/MapFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ void mpart::binding::MapFactoryWrapper(py::module &m)
m.def(isDevice? "dCreateSingleEntryMap" : "CreateSingleEntryMap", &MapFactory::CreateSingleEntryMap<MemorySpace>);

// CreateSigmoidComponent
m.def(isDevice? "dCreateSigmoidComponent" : "CreateSigmoidComponent", py::overload_cast<unsigned int, unsigned int, unsigned int, Eigen::Ref<const Eigen::RowVectorXd>, MapOptions>(&MapFactory::CreateSigmoidComponent<MemorySpace>));
m.def(isDevice? "dCreateSigmoidComponent" : "CreateSigmoidComponent", py::overload_cast<unsigned int, unsigned int, Eigen::Ref<const Eigen::RowVectorXd>, MapOptions>(&MapFactory::CreateSigmoidComponent<MemorySpace>));
m.def(isDevice? "dCreateSigmoidComponent" : "CreateSigmoidComponent", py::overload_cast<FixedMultiIndexSet<MemorySpace>, Eigen::Ref<const Eigen::RowVectorXd>, MapOptions>(&MapFactory::CreateSigmoidComponent<MemorySpace>));

Expand Down

0 comments on commit 75bfc3d

Please sign in to comment.