Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 132 additions & 8 deletions src/wmtk/Scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,14 @@
#include <wmtk/attribute/TypedAttributeHandle.hpp>
#include <wmtk/simplex/k_ring.hpp>
#include <wmtk/simplex/link.hpp>
#include <wmtk/simplex/link_single_dimension_iterable.hpp>
#include <wmtk/simplex/utils/tuple_vector_to_homogeneous_simplex_vector.hpp>
#include <wmtk/utils/Logger.hpp>
#include <wmtk/utils/random_seed.hpp>
#include <wmtk/utils/tbb_parallel_for.hpp>

#include <polysolve/Utils.hpp>

#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wredundant-decls"
#endif
#include <tbb/parallel_for.h>
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

// #include <tbb/task_arena.h>
#include <atomic>
Expand Down Expand Up @@ -374,6 +368,136 @@ SchedulerStats Scheduler::run_operation_on_all_coloring(
return res;
}

int64_t Scheduler::color_vertices(attribute::MeshAttributeHandle& color_handle)
{
if (color_handle.primitive_type() != PrimitiveType::Vertex) {
log_and_throw_error("Color handle must be of primitive type vertex.");
}
if (color_handle.held_type() != attribute::MeshAttributeHandle::HeldType::Int64) {
log_and_throw_error("Color handle must be of type int64_t.");
}
if (color_handle.dimension() != 1) {
log_and_throw_error("Color handle must be of dimension 1.");
}

Mesh& m = color_handle.mesh();

auto acc = m.create_accessor<int64_t>(color_handle);

const auto vertices = m.get_all(PrimitiveType::Vertex);
for (const Tuple& t : vertices) {
acc.scalar_attribute(t) = -1;
}

int64_t max_color = -1;

for (const Tuple& t : vertices) {
const simplex::Simplex v(m, PrimitiveType::Vertex, t);
auto link_vertices = simplex::link_single_dimension_iterable(m, v, PrimitiveType::Vertex);

std::vector<int64_t> neighbor_colors;

for (const Tuple& neighbor_tuple : link_vertices) {
// max_neighbor = std::max(acc.const_scalar_attribute(neighbor_tuple), max_neighbor);
const int64_t c = acc.const_scalar_attribute(neighbor_tuple);
if (c >= 0) {
neighbor_colors.emplace_back(c);
}
}
const int64_t t_color = first_available_color(neighbor_colors);

acc.scalar_attribute(t) = t_color;
max_color = std::max(max_color, t_color);
}


logger().info("{} vertices with {} different colors", vertices.size(), max_color + 1);

return max_color + 1;
}

SchedulerStats Scheduler::run_operation_on_all_with_coloring(
operations::Operation& op,
attribute::MeshAttributeHandle& color_handle,
int64_t num_colors,
bool parallel_execution)
{
if (&op.mesh() != &color_handle.mesh()) {
log_and_throw_error("Operation and color handle do not belong to the same mesh!");
}
if (color_handle.primitive_type() != PrimitiveType::Vertex) {
log_and_throw_error("Color handle must be of primitive type vertex.");
}
if (color_handle.held_type() != attribute::MeshAttributeHandle::HeldType::Int64) {
log_and_throw_error("Color handle must be of type int64_t.");
}
if (color_handle.dimension() != 1) {
log_and_throw_error("Color handle must be of dimension 1.");
}

if (num_colors < 0) {
num_colors = color_vertices(color_handle);
}

Mesh& m = op.mesh();

auto color_acc = m.create_const_accessor<int64_t>(color_handle);

const auto vertices = m.get_all(op.primitive_type());

std::vector<std::vector<simplex::Simplex>> colored_vertices;
colored_vertices.resize(num_colors);

for (int64_t color = 0; color < num_colors; ++color) {
colored_vertices[color].reserve(vertices.size() / num_colors);
}

for (const Tuple& t : vertices) {
colored_vertices[color_acc.const_scalar_attribute(t)].emplace_back(
m,
op.primitive_type(),
t);
}

SchedulerStats res;

for (auto& one_color_vertices : colored_vertices) {
std::atomic_int suc_cnt = 0;
std::atomic_int fail_cnt = 0;

if (parallel_execution) {
tbb::parallel_for(
tbb::blocked_range<int64_t>(0, one_color_vertices.size()),
[&](tbb::blocked_range<int64_t> r) {
for (int64_t k = r.begin(); k < r.end(); ++k) {
auto mods = op(one_color_vertices[k]);
if (mods.empty()) {
fail_cnt++;
} else {
suc_cnt++;
}
}
});
} else {
for (const simplex::Simplex& s : one_color_vertices) {
auto mods = op(s);
if (mods.empty()) {
fail_cnt++;
} else {
suc_cnt++;
}
}
}

res.m_num_op_success = suc_cnt;
res.m_num_op_fail = fail_cnt;
}

m_stats += res;

return res;
}

void Scheduler::log(size_t total)
{
log(m_stats, total);
Expand Down
25 changes: 25 additions & 0 deletions src/wmtk/Scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,31 @@ class Scheduler
operations::Operation& op,
const TypedAttributeHandle<int64_t>& color_handle);

/**
* @brief Add vertex colors to perform vertex optimization in parallel.
*
* @param color_handle A vertex int64_t scalar attribute representing the color.
* @return number of colors
*/
int64_t color_vertices(attribute::MeshAttributeHandle& color_handle);

/**
* @brief Run op on all vertices in parallel using the coloring.
*
* Potential race conditions!!!
* Attribute transfer applies changes to the entire closed star, which leads to race conditions
* for any transfer to edges or vertices. Do not use this function in that case!!!
*
* @param op The operation, must be of primitive type vertex.
* @param color_handle The attribute holding the vertex int64_t scalar coloring scheme.
* @param num_colors The number of different colors. If negative, the coloring is initialized.
*/
SchedulerStats run_operation_on_all_with_coloring(
operations::Operation& op,
attribute::MeshAttributeHandle& color_handle,
int64_t num_colors = -1,
bool parallel_execution = true);

const SchedulerStats& stats() const { return m_stats; }

void set_update_frequency(std::optional<size_t>&& freq = {});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "AttributeTransferStrategyBase.hpp"
#include <wmtk/Mesh.hpp>
#include <wmtk/simplex/neighbors_single_dimension.hpp>
#include <wmtk/utils/tbb_parallel_for.hpp>

#include <wmtk/simplex/utils/unique_homogeneous_simplices.hpp>

Expand All @@ -16,13 +17,23 @@ const Mesh& AttributeTransferStrategyBase::mesh() const
return const_cast<const Mesh&>(const_cast<AttributeTransferStrategyBase*>(this)->mesh());
}

void AttributeTransferStrategyBase::run_on_all() const
void AttributeTransferStrategyBase::run_on_all(bool parallel) const
{
const PrimitiveType pt = m_handle.primitive_type();
auto tuples = m_handle.mesh().get_all(pt);

for (const Tuple& t : tuples) {
run(simplex::Simplex(m_handle.mesh(), pt, t));
if (parallel) {
tbb::parallel_for(
tbb::blocked_range<int64_t>(0, tuples.size()),
[&](tbb::blocked_range<int64_t> r) {
for (int64_t k = r.begin(); k < r.end(); ++k) {
run(simplex::Simplex(m_handle.mesh(), pt, tuples[k]));
}
});
} else {
for (const Tuple& t : tuples) {
run(simplex::Simplex(m_handle.mesh(), pt, t));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class AttributeTransferStrategyBase : public AttributeTransferEdge

// runs the transfer on every simplex - good for initializing an attribute that will be
// managed by transfer
void run_on_all() const;
void run_on_all(bool parallel = false) const;

private:
attribute::MeshAttributeHandle m_handle;
Expand Down
2 changes: 2 additions & 0 deletions src/wmtk/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ set(SRC_FILES
orient.hpp
orient.cpp

tbb_parallel_for.hpp

TupleInspector.cpp
triangle_areas.hpp
triangle_areas.cpp
Expand Down
10 changes: 10 additions & 0 deletions src/wmtk/utils/tbb_parallel_for.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once

#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wredundant-decls"
#endif
#include <tbb/parallel_for.h>
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
Loading