Skip to content

Commit

Permalink
Added MaxWeight mean picking + fixed the algorithm selection
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-pr committed May 4, 2020
1 parent 3ac4212 commit bb11680
Showing 1 changed file with 62 additions and 17 deletions.
79 changes: 62 additions & 17 deletions src/plugins/opencv/nodes/superpixels/mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,48 @@

namespace {

struct MeanAttrs {
virtual ~MeanAttrs() = default;
enum Mode {
kMean,
kMedian,
kMaxWeight
};

struct MeanAttrs {
dependency_graph::InAttr<possumwood::opencv::Frame> a_in;
dependency_graph::InAttr<possumwood::opencv::Frame> a_superpixels;
dependency_graph::InAttr<possumwood::Enum> a_mode;
dependency_graph::OutAttr<possumwood::opencv::Frame> a_out;

virtual std::function<float(int, int)> weightFn(dependency_graph::Values& data) const;
std::function<float(int, int)> weightFn(dependency_graph::Values& data) const;
void init(possumwood::Metadata& meta);

virtual const std::vector<std::pair<std::string, int>>& modes() const {
static std::vector<std::pair<std::string, int>> s_mode {
{"mean", kMean},
{"median", kMedian},
};

return s_mode;
}

dependency_graph::State compute(dependency_graph::Values& data) const;
};

struct WeightedMeanAttrs : public MeanAttrs {
dependency_graph::InAttr<possumwood::opencv::Frame> a_weights;

virtual std::function<float(int, int)> weightFn(dependency_graph::Values& data) const override;
std::function<float(int, int)> weightFn(dependency_graph::Values& data) const;
void init(possumwood::Metadata& meta);

virtual const std::vector<std::pair<std::string, int>>& modes() const override {
static std::vector<std::pair<std::string, int>> s_mode {
{"mean", kMean},
{"median", kMedian},
{"max weight", kMaxWeight},
};

return s_mode;
}
};

MeanAttrs s_meanAttrs;
Expand Down Expand Up @@ -59,16 +82,6 @@ void setFloat(cv::Mat& in, int row, int col, int channel, float val) {
}
}

enum Mode {
kMean,
kMedian,
};

static std::vector<std::pair<std::string, int>> s_mode {
{"mean", kMean},
{"median", kMedian},
};

class Mean {
public:
Mean() : m_val(0.0f), m_norm(0.0f) {
Expand Down Expand Up @@ -109,7 +122,8 @@ class Median {
if(m_weightsSum == 0.0f)
return 0.0f;

std::sort(m_val.begin(), m_val.end());
std::sort(m_val.begin(), m_val.end(),
[](const std::pair<float, float>& p1, const std::pair<float, float>& p2) { return p1.first < p2.first; });

auto it = m_val.begin();
float weight = 0.0f;
Expand All @@ -130,6 +144,35 @@ class Median {
float m_weightsSum;
};

class MaxWeight {
public:
MaxWeight() {
}

void add(float val, float weight) {
if(weight < 0.0f)
throw std::runtime_error("Negative weight in weights input!");

m_vals[val] += weight;
}

float operator*() const {
if(!m_vals.empty()) {
auto max = m_vals.begin();

for(auto it = m_vals.begin(); it != m_vals.end(); ++it)
if(max->second < it->second)
max = it;

return max->first;
}
return 0.0f;
}

private:
std::map<float, float> m_vals;
};

template<typename MEAN>
cv::Mat process(const cv::Mat& in, const cv::Mat& superpixels, std::function<float(int, int)> weights = [](int, int) {return 1.0f;}) {
// first of all, get the maximum index of the superpixels
Expand All @@ -139,7 +182,7 @@ cv::Mat process(const cv::Mat& in, const cv::Mat& superpixels, std::function<flo
maxIndex = std::max(maxIndex, superpixels.at<int32_t>(row, col));

// make the right sized accumulator and norm array
std::vector<std::vector<Mean>> vals(in.channels(), std::vector<Mean>(maxIndex+1, Mean()));
std::vector<std::vector<MEAN>> vals(in.channels(), std::vector<MEAN>(maxIndex+1, MEAN()));

// and accumulate the values
for(int row=0; row<in.rows; ++row)
Expand Down Expand Up @@ -199,6 +242,8 @@ dependency_graph::State MeanAttrs::compute(dependency_graph::Values& data) const
out = process<Mean>(in, superpixels, weights);
else if(data.get(a_mode).intValue() == kMedian)
out = process<Median>(in, superpixels, weights);
else if(data.get(a_mode).intValue() == kMaxWeight)
out = process<MaxWeight>(in, superpixels, weights);
else
throw std::runtime_error("Unknown mode " + data.get(a_mode).value());

Expand All @@ -210,7 +255,7 @@ dependency_graph::State MeanAttrs::compute(dependency_graph::Values& data) const
void MeanAttrs::init(possumwood::Metadata& meta) {
meta.addAttribute(a_in, "in", possumwood::opencv::Frame(), possumwood::AttrFlags::kVertical);
meta.addAttribute(a_superpixels, "superpixels", possumwood::opencv::Frame(), possumwood::AttrFlags::kVertical);
meta.addAttribute(a_mode, "mode", possumwood::Enum(s_mode.begin(), s_mode.end()));
meta.addAttribute(a_mode, "mode", possumwood::Enum(modes().begin(), modes().end()));
meta.addAttribute(a_out, "out", possumwood::opencv::Frame(), possumwood::AttrFlags::kVertical);

meta.addInfluence(a_in, a_out);
Expand Down

0 comments on commit bb11680

Please sign in to comment.