Skip to content

Commit

Permalink
Abstracted Mean computation, to allow for a new weighted mean node + …
Browse files Browse the repository at this point in the history
…first version of superpixel depth demo
  • Loading branch information
martin-pr committed May 3, 2020
1 parent d971295 commit 3ac4212
Show file tree
Hide file tree
Showing 2 changed files with 1,133 additions and 22 deletions.
109 changes: 87 additions & 22 deletions src/plugins/opencv/nodes/superpixels/mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,28 @@
namespace {

struct MeanAttrs {
virtual ~MeanAttrs() = default;

dependency_graph::InAttr<possumwood::opencv::Frame> a_in;
dependency_graph::InAttr<possumwood::opencv::Frame> a_superpixels;
dependency_graph::InAttr<possumwood::Enum> a_mode;
dependency_graph::OutAttr<possumwood::opencv::Frame> a_out;

virtual std::function<float(int, int)> weightFn(dependency_graph::Values& data) const;
void init(possumwood::Metadata& meta);

dependency_graph::State compute(dependency_graph::Values& data) const;
};

struct WeightedMeanAttrs : public MeanAttrs {
dependency_graph::InAttr<possumwood::opencv::Frame> a_weights;

virtual std::function<float(int, int)> weightFn(dependency_graph::Values& data) const override;
void init(possumwood::Metadata& meta);
};

MeanAttrs s_meanAttrs;
WeightedMeanAttrs s_weightedMeanAttrs;

float getFloat(const cv::Mat& in, int row, int col, int channel) {
switch(in.depth()) {
Expand Down Expand Up @@ -61,11 +74,12 @@ class Mean {
Mean() : m_val(0.0f), m_norm(0.0f) {
}

Mean& operator += (float val) {
m_val += val;
m_norm += 1.0f;
void add(float val, float weight) {
if(weight < 0.0f)
throw std::runtime_error("Negative weight in weights input!");

return *this;
m_val += val * weight;
m_norm += weight;
}

float operator*() const {
Expand All @@ -80,27 +94,44 @@ class Mean {

class Median {
public:
Median& operator += (float val) {
m_val.push_back(val);
Median() : m_weightsSum(0.0f) {
}

return *this;
void add(float val, float weight) {
if(weight < 0.0f)
throw std::runtime_error("Negative weight in weights input!");

m_val.push_back(std::make_pair(val, weight));
m_weightsSum += weight;
}

float operator*() {
if(m_val.empty())
if(m_weightsSum == 0.0f)
return 0.0f;

std::sort(m_val.begin(), m_val.end());

return m_val[m_val.size()/2];
auto it = m_val.begin();
float weight = 0.0f;
weight += it->second;

while(weight < m_weightsSum/2.0f && it != m_val.end()-1) {
++it;

if(it != m_val.end())
weight += it->second;
}

return it->first;
}

private:
std::vector<float> m_val;
std::vector<std::pair<float, float>> m_val;
float m_weightsSum;
};

template<typename MEAN>
cv::Mat process(const cv::Mat& in, const cv::Mat& superpixels) {
cv::Mat process(const cv::Mat& in, const cv::Mat& superpixels, std::function<float(int, int)> weights = [](int, int) {return 1.0f;}) {
// first of all, get the maximum index of the superpixels
int32_t maxIndex = 0;
for(int row=0; row<superpixels.rows; ++row)
Expand All @@ -116,7 +147,7 @@ cv::Mat process(const cv::Mat& in, const cv::Mat& superpixels) {
const int32_t index = superpixels.at<int32_t>(row, col);

for(int c=0;c<in.channels();++c)
vals[c][index] += getFloat(in, row, col, c);
vals[c][index].add(getFloat(in, row, col, c), weights(row, col));
}

cv::Mat out = cv::Mat::zeros(in.rows, in.cols, in.type());
Expand All @@ -132,25 +163,46 @@ cv::Mat process(const cv::Mat& in, const cv::Mat& superpixels) {
return out;
}

dependency_graph::State compute(dependency_graph::Values& data, MeanAttrs& attrs) {
const cv::Mat& in = *data.get(attrs.a_in);
const cv::Mat& superpixels = *data.get(attrs.a_superpixels);
std::function<float(int, int)> MeanAttrs::weightFn(dependency_graph::Values& data) const {
return [](int, int) { return 1.0f; };
}

std::function<float(int, int)> WeightedMeanAttrs::weightFn(dependency_graph::Values& data) const {
const cv::Mat& in = *data.get(a_in);
const cv::Mat& weights = *data.get(a_weights);

if(in.rows != weights.rows || in.cols != weights.cols)
throw std::runtime_error("Input and weights size have to match.");

if(weights.type() != CV_32FC1)
throw std::runtime_error("Only CV_32FC1 type supported on the weights input!");

return [&](int row, int col) {
return weights.at<float>(row, col);
};
}

dependency_graph::State MeanAttrs::compute(dependency_graph::Values& data) const {
const cv::Mat& in = *data.get(a_in);
const cv::Mat& superpixels = *data.get(a_superpixels);

if(in.rows != superpixels.rows || in.cols != superpixels.cols)
throw std::runtime_error("Input and superpixel size have to match.");

if(superpixels.type() != CV_32SC1)
throw std::runtime_error("Only CV_32SC1 type supported on the superpixels input!");

std::function<float(int, int)> weights = weightFn(data);

cv::Mat out;
if(data.get(attrs.a_mode).intValue() == kMean)
out = process<Mean>(in, superpixels);
else if(data.get(attrs.a_mode).intValue() == kMedian)
out = process<Median>(in, superpixels);
if(data.get(a_mode).intValue() == kMean)
out = process<Mean>(in, superpixels, weights);
else if(data.get(a_mode).intValue() == kMedian)
out = process<Median>(in, superpixels, weights);
else
throw std::runtime_error("Unknown mode " + data.get(attrs.a_mode).value());
throw std::runtime_error("Unknown mode " + data.get(a_mode).value());

data.set(attrs.a_out, possumwood::opencv::Frame(out));
data.set(a_out, possumwood::opencv::Frame(out));

return dependency_graph::State();
}
Expand All @@ -166,10 +218,23 @@ void MeanAttrs::init(possumwood::Metadata& meta) {
meta.addInfluence(a_mode, a_out);

meta.setCompute([this](dependency_graph::Values& data) {
return compute(data, *this);
return compute(data);
});
}

void WeightedMeanAttrs::init(possumwood::Metadata& meta) {
MeanAttrs::init(meta);

meta.addAttribute(a_weights, "weights", possumwood::opencv::Frame(), possumwood::AttrFlags::kVertical);

meta.addInfluence(a_weights, a_out);

meta.setCompute([this](dependency_graph::Values& data) {
return compute(data);
});
}

possumwood::NodeImplementation s_impl("opencv/superpixels/mean", [](possumwood::Metadata& meta) { s_meanAttrs.init(meta); });
possumwood::NodeImplementation s_implW("opencv/superpixels/weighted_mean", [](possumwood::Metadata& meta) { s_weightedMeanAttrs.init(meta); });

}
Loading

0 comments on commit 3ac4212

Please sign in to comment.