Skip to content

Enhance/const correct(er) dataobjects #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into the base branch from the contributor's branch
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/algorithms/public/MLP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class MLP
}

void processFrame(RealVectorView in, RealVectorView out, index startLayer,
index endLayer)
index endLayer) const
{
using namespace _impl;
using namespace Eigen;
Expand All @@ -113,13 +113,13 @@ class MLP
out <<= asFluid(tmpOut);
}

void forward(Eigen::Ref<ArrayXXd> in, Eigen::Ref<ArrayXXd> out)
void forward(Eigen::Ref<ArrayXXd> in, Eigen::Ref<ArrayXXd> out) const
{
forward(in, out, 0, asSigned(mLayers.size()));
}

void forward(Eigen::Ref<ArrayXXd> in, Eigen::Ref<ArrayXXd> out,
index startLayer, index endLayer)
index startLayer, index endLayer) const
{
if (startLayer >= asSigned(mLayers.size()) ||
endLayer > asSigned(mLayers.size()))
Expand All @@ -137,7 +137,7 @@ class MLP
out = output;
}

void backward(Eigen::Ref<ArrayXXd> out)
void backward(Eigen::Ref<ArrayXXd> out)
{
index nRows = out.rows();
ArrayXXd chain =
Expand Down
24 changes: 12 additions & 12 deletions include/algorithms/public/UMAP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class UMAP
return out;
}

DataSet transform(DataSet& in, index maxIter = 200, double learningRate = 1.0)
DataSet transform(DataSet& in, index maxIter = 200, double learningRate = 1.0) const
{
if (!mInitialized) return DataSet();
SparseMatrixXd knnGraph(in.size(), mEmbedding.rows());
Expand All @@ -158,7 +158,7 @@ class UMAP
}


void transformPoint(RealVectorView in, RealVectorView out)
void transformPoint(RealVectorView in, RealVectorView out) const
{
if (!mInitialized) return;
SparseMatrixXd knnGraph(1, mEmbedding.rows());
Expand All @@ -185,7 +185,7 @@ class UMAP

private:
template <typename F>
void traverseGraph(const SparseMatrixXd& graph, F func)
void traverseGraph(const SparseMatrixXd& graph, F func) const
{
for (index i = 0; i < graph.outerSize(); i++)
{
Expand All @@ -204,7 +204,7 @@ class UMAP
}

ArrayXd findSigma(index k, Ref<ArrayXXd> dists, index maxIter = 64,
double tolerance = 1e-5)
double tolerance = 1e-5) const
{
using namespace std;
double target = log2(k);
Expand Down Expand Up @@ -242,7 +242,7 @@ class UMAP
}

void computeHighDimProb(const Ref<ArrayXXd>& dists, const Ref<ArrayXd>& sigma,
SparseMatrixXd& graph)
SparseMatrixXd& graph) const
{
traverseGraph(graph, [&](auto it) {
it.valueRef() =
Expand All @@ -263,7 +263,7 @@ class UMAP
}

void makeGraph(const DataSet& in, index k, SparseMatrixXd& graph,
Ref<ArrayXXd> dists, bool discardFirst)
Ref<ArrayXXd> dists, bool discardFirst) const
{
graph.reserve(in.size() * k);
auto data = in.getData();
Expand Down Expand Up @@ -298,7 +298,7 @@ class UMAP
}

void getGraphIndices(const SparseMatrixXd& graph, Ref<ArrayXi> rowIndices,
Ref<ArrayXi> colIndices)
Ref<ArrayXi> colIndices) const
{
index p = 0;
traverseGraph(graph, [&](auto it) {
Expand All @@ -309,7 +309,7 @@ class UMAP
}

void computeEpochsPerSample(const SparseMatrixXd& graph,
Ref<ArrayXd> epochsPerSample)
Ref<ArrayXd> epochsPerSample) const
{
index p = 0;
double maxVal = graph.coeffs().maxCoeff();
Expand All @@ -321,7 +321,7 @@ class UMAP
void optimizeLayout(Ref<ArrayXXd> embedding, Ref<ArrayXXd> reference,
Ref<ArrayXi> embIndices, Ref<ArrayXi> refIndices,
Ref<ArrayXd> epochsPerSample, bool updateReference,
double learningRate, index maxIter, double gamma = 1.0)
double learningRate, index maxIter, double gamma = 1.0) const
{
using namespace std;
double alpha = learningRate;
Expand Down Expand Up @@ -385,7 +385,7 @@ class UMAP
}

ArrayXXd initTransformEmbedding(const SparseMatrixXd& graph,
Ref<ArrayXXd> reference, index N)
Ref<const ArrayXXd> reference, index N) const
{
ArrayXXd embedding = ArrayXXd::Zero(N, reference.cols());
traverseGraph(graph, [&](auto it) {
Expand All @@ -394,7 +394,7 @@ class UMAP
return embedding;
}

void normalizeRows(const SparseMatrixXd& graph)
void normalizeRows(const SparseMatrixXd& graph) const
{
ArrayXd sums = ArrayXd::Zero(graph.innerSize());
traverseGraph(graph, [&](auto it) { sums(it.row()) += it.value(); });
Expand All @@ -406,7 +406,7 @@ class UMAP
KDTree mTree;
index mK;
VectorXd mAB;
ArrayXXd mEmbedding;
mutable ArrayXXd mEmbedding;
bool mInitialized{false};
};
}// namespace algorithm
Expand Down
6 changes: 3 additions & 3 deletions include/algorithms/util/NNLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class NNLayer

index outputSize() const { return mWeights.cols(); }

void forward(Eigen::Ref<MatrixXd> in, Eigen::Ref<MatrixXd> out)
void forward(Eigen::Ref<MatrixXd> in, Eigen::Ref<MatrixXd> out) const
{
mInput = in;
MatrixXd WT = mWeights.transpose();
Expand Down Expand Up @@ -114,8 +114,8 @@ class NNLayer
MatrixXd mPrevWeightsUpdate;
VectorXd mPrevBiasesUpdate;

MatrixXd mInput;
MatrixXd mOutput;
mutable MatrixXd mInput;
mutable MatrixXd mOutput;
};
} // namespace algorithm
} // namespace fluid
4 changes: 2 additions & 2 deletions include/clients/common/SharedClientUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class SharedClientRef

SharedClientRef() {}
SharedClientRef(const char* name) : mName{name} {}
WeakPointer get() { return {SharedType::lookup(mName)}; }
WeakPointer get() const { return {SharedType::lookup(mName)}; }
void set(const char* name) { mName = std::string(name); }
const char* name() { return mName.c_str(); }
const char* name() const { return mName.c_str(); }

// Supporting machinery for making new parameter types

Expand Down
4 changes: 2 additions & 2 deletions include/clients/nrt/ClientInputChecks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class InBufferCheck : public ClientInputCheck
{
public:
InBufferCheck(index size) : mInputSize(size){};
bool checkInputs(BufferAdaptor* inputPtr)
bool checkInputs(const BufferAdaptor* inputPtr)
{
if (!inputPtr)
{
Expand Down Expand Up @@ -61,7 +61,7 @@ class InOutBuffersCheck : public InBufferCheck

public:
using InBufferCheck::InBufferCheck;
bool checkInputs(BufferAdaptor* inputPtr, BufferAdaptor* outputPtr)
bool checkInputs(const BufferAdaptor* inputPtr, BufferAdaptor* outputPtr)
{
if (!InBufferCheck::checkInputs(inputPtr)) { return false; }
if (!outputPtr)
Expand Down
8 changes: 4 additions & 4 deletions include/clients/nrt/DataClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ class DataClient
public:
using string = std::string;

MessageResult<index> size() { return mAlgorithm.size(); }
MessageResult<index> size() const { return mAlgorithm.size(); }

MessageResult<index> dims() { return mAlgorithm.dims(); }
MessageResult<index> dims() const { return mAlgorithm.dims(); }

MessageResult<void> clear()
{
Expand Down Expand Up @@ -80,8 +80,8 @@ class DataClient
}
}

bool initialized() { return mAlgorithm.initialized(); }
T& algorithm() { return mAlgorithm; }
bool initialized() const { return mAlgorithm.initialized(); }
T const& algorithm() const { return mAlgorithm; }
protected:
T mAlgorithm;
};
Expand Down
23 changes: 13 additions & 10 deletions include/clients/nrt/DataSetClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DataSetClient : public FluidBaseClient,
public:
using string = std::string;
using BufferPtr = std::shared_ptr<BufferAdaptor>;
using InputBufferPtr = std::shared_ptr<const BufferAdaptor>;
using DataSet = FluidDataSet<string, double, 1>;
using LabelSet = FluidDataSet<string, string, 1>;

Expand All @@ -61,11 +62,11 @@ class DataSetClient : public FluidBaseClient,

DataSetClient(ParamSetViewType& p) : mParams(p) {}

MessageResult<void> addPoint(string id, BufferPtr data)
MessageResult<void> addPoint(string id, InputBufferPtr data)
{
DataSet& dataset = mAlgorithm;
if (!data) return Error(NoBuffer);
BufferAdaptor::Access buf(data.get());
BufferAdaptor::ReadAccess buf(data.get());
if (!buf.exists()) return Error(InvalidBuffer);
if (buf.numFrames() == 0) return Error(EmptyBuffer);
if (dataset.size() == 0)
Expand Down Expand Up @@ -101,23 +102,23 @@ class DataSetClient : public FluidBaseClient,
}
}

MessageResult<void> updatePoint(string id, BufferPtr data)
MessageResult<void> updatePoint(string id, InputBufferPtr data)
{
if (!data) return Error(NoBuffer);
BufferAdaptor::Access buf(data.get());
BufferAdaptor::ReadAccess buf(data.get());
if (!buf.exists()) return Error(InvalidBuffer);
if (buf.numFrames() < mAlgorithm.dims()) return Error(WrongPointSize);
RealVector point(mAlgorithm.dims());
point <<= buf.samps(0, mAlgorithm.dims(), 0);
return mAlgorithm.update(id, point) ? OK() : Error(PointNotFound);
}

MessageResult<void> setPoint(string id, BufferPtr data)
MessageResult<void> setPoint(string id, InputBufferPtr data)
{
if (!data) return Error(NoBuffer);

{ // restrict buffer lock to this scope in case addPoint is called
BufferAdaptor::Access buf(data.get());
BufferAdaptor::ReadAccess buf(data.get());
if (!buf.exists()) return Error(InvalidBuffer);
if (buf.numFrames() < mAlgorithm.dims()) return Error(WrongPointSize);
RealVector point(mAlgorithm.dims());
Expand All @@ -133,7 +134,7 @@ class DataSetClient : public FluidBaseClient,
return mAlgorithm.remove(id) ? OK() : Error(PointNotFound);
}

MessageResult<void> merge(SharedClientRef<DataSetClient> datasetClient,
MessageResult<void> merge(SharedClientRef<const DataSetClient> datasetClient,
bool overwrite)
{
auto datasetClientPtr = datasetClient.get().lock();
Expand All @@ -154,11 +155,11 @@ class DataSetClient : public FluidBaseClient,
}

MessageResult<void>
fromBuffer(BufferPtr data, bool transpose,
SharedClientRef<labelset::LabelSetClient> labels)
fromBuffer(InputBufferPtr data, bool transpose,
SharedClientRef<const labelset::LabelSetClient> labels)
{
if (!data) return Error(NoBuffer);
BufferAdaptor::Access buf(data.get());
BufferAdaptor::ReadAccess buf(data.get());
if (!buf.exists()) return Error(InvalidBuffer);
auto bufView = transpose ? buf.allFrames() : buf.allFrames().transpose();
if (auto labelsPtr = labels.get().lock())
Expand Down Expand Up @@ -256,6 +257,8 @@ class DataSetClient : public FluidBaseClient,
} // namespace dataset

using DataSetClientRef = SharedClientRef<dataset::DataSetClient>;
using InputDataSetClientRef = SharedClientRef<const dataset::DataSetClient>;

using NRTThreadedDataSetClient =
NRTThreadingAdaptor<typename DataSetClientRef::SharedType>;

Expand Down
6 changes: 3 additions & 3 deletions include/clients/nrt/DataSetQueryClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class DataSetQueryClient : public FluidBaseClient, OfflineIn, OfflineOut
}


MessageResult<void> transform(DataSetClientRef sourceClient,
MessageResult<void> transform(InputDataSetClientRef sourceClient,
DataSetClientRef destClient)
{
if (mAlgorithm.numColumns() <= 0) return Error("No columns");
Expand All @@ -118,8 +118,8 @@ class DataSetQueryClient : public FluidBaseClient, OfflineIn, OfflineOut
return OK();
}

MessageResult<void> transformJoin(DataSetClientRef source1Client,
DataSetClientRef source2Client,
MessageResult<void> transformJoin(InputDataSetClientRef source1Client,
InputDataSetClientRef source2Client,
DataSetClientRef destClient)
{
auto src1Ptr = source1Client.get().lock();
Expand Down
2 changes: 1 addition & 1 deletion include/clients/nrt/GridClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class GridClient : public FluidBaseClient, OfflineIn, OfflineOut, ModelObject

GridClient(ParamSetViewType& p) : mParams(p) {}

MessageResult<void> fitTransform(DataSetClientRef sourceClient,
MessageResult<void> fitTransform(InputDataSetClientRef sourceClient,
DataSetClientRef destClient)
{
auto srcPtr = sourceClient.get().lock();
Expand Down
21 changes: 11 additions & 10 deletions include/clients/nrt/KDTreeClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class KDTreeClient : public FluidBaseClient,
public:
using string = std::string;
using BufferPtr = std::shared_ptr<BufferAdaptor>;
using InputBufferPtr = std::shared_ptr<const BufferAdaptor>;
using StringVector = FluidTensor<string, 1>;
using ParamDescType = decltype(KDTreeParams);

Expand Down Expand Up @@ -63,7 +64,7 @@ class KDTreeClient : public FluidBaseClient,
return {};
}

MessageResult<void> fit(DataSetClientRef datasetClient)
MessageResult<void> fit(InputDataSetClientRef datasetClient)
{
mDataSetClient = datasetClient;
auto datasetClientPtr = mDataSetClient.get().lock();
Expand All @@ -74,7 +75,7 @@ class KDTreeClient : public FluidBaseClient,
return OK();
}

MessageResult<StringVector> kNearest(BufferPtr data) const
MessageResult<StringVector> kNearest(InputBufferPtr data) const
{
index k = get<kNumNeighbors>();
if (k > mAlgorithm.size()) return Error<StringVector>(SmallDataSet);
Expand All @@ -92,7 +93,7 @@ class KDTreeClient : public FluidBaseClient,
return result;
}

MessageResult<RealVector> kNearestDist(BufferPtr data) const
MessageResult<RealVector> kNearestDist(InputBufferPtr data) const
{
// TODO: refactor with kNearest
index k = get<kNumNeighbors>();
Expand Down Expand Up @@ -126,22 +127,22 @@ class KDTreeClient : public FluidBaseClient,
makeMessage("read", &KDTreeClient::read));
}

DataSetClientRef getDataSet() { return mDataSetClient; }
InputDataSetClientRef getDataSet() const { return mDataSetClient; }

const algorithm::KDTree& algorithm() { return mAlgorithm; }
const algorithm::KDTree& algorithm() const { return mAlgorithm; }

private:
DataSetClientRef mDataSetClient;
InputDataSetClientRef mDataSetClient;
};

using KDTreeRef = SharedClientRef<KDTreeClient>;
using KDTreeRef = SharedClientRef<const KDTreeClient>;

constexpr auto KDTreeQueryParams = defineParameters(
KDTreeRef::makeParam("tree", "KDTree"),
LongParam("numNeighbours", "Number of Nearest Neighbours", 1),
FloatParam("radius", "Maximum distance", 0, Min(0)),
DataSetClientRef::makeParam("dataSet", "DataSet Name"),
BufferParam("inputPointBuffer", "Input Point Buffer"),
InputDataSetClientRef::makeParam("dataSet", "DataSet Name"),
InputBufferParam("inputPointBuffer", "Input Point Buffer"),
BufferParam("predictionBuffer", "Prediction Buffer"));

class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut
Expand Down Expand Up @@ -238,7 +239,7 @@ class KDTreeQuery : public FluidBaseClient, ControlIn, ControlOut

private:
RealVector mRTBuffer;
DataSetClientRef mDataSetClient;
InputDataSetClientRef mDataSetClient;
};

} // namespace kdtree
Expand Down
Loading