Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement get_pcg_series_parallel_decomposition #1598

Merged
merged 4 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions cmake/doctestlib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ include(aliasing)
if (FF_USE_EXTERNAL_DOCTEST)
find_package(doctest REQUIRED)
include(doctest) # import doctest_discover_tests

target_compile_definitions(
doctest::doctest
INTERFACE
DOCTEST_CONFIG_REQUIRE_STRINGIFICATION_FOR_ALL_USED_TYPES
)
alias_library(doctest doctest::doctest)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest)
include(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest/scripts/cmake/doctest.cmake)
endif()

target_compile_definitions(
doctest::doctest
INTERFACE
DOCTEST_CONFIG_REQUIRE_STRINGIFICATION_FOR_ALL_USED_TYPES
)
alias_library(doctest doctest::doctest)
13 changes: 13 additions & 0 deletions lib/compiler/include/compiler/graph_optimize_result.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H

#include "compiler/graph_optimize_result.dtg.h"

namespace FlexFlow {

std::string format_as(GraphOptimizeResult const &);
std::ostream &operator<<(std::ostream &, GraphOptimizeResult const &);

} // namespace FlexFlow

#endif
7 changes: 5 additions & 2 deletions lib/compiler/include/compiler/graph_optimize_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
namespace FlexFlow {

struct GraphOptimizeState {
GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result,
float runtime);
explicit GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result,
float runtime);

GraphOptimizeResult graph_optimize_result;
float runtime;
Expand All @@ -17,6 +17,9 @@ struct GraphOptimizeState {
bool operator<(GraphOptimizeState const &other) const;
};

std::string format_as(GraphOptimizeState const &);
std::ostream &operator<<(std::ostream &, GraphOptimizeState const &);

} // namespace FlexFlow

namespace std {
Expand Down
11 changes: 11 additions & 0 deletions lib/compiler/src/compiler/graph_optimize_state.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "compiler/graph_optimize_state.h"
#include "compiler/graph_optimize_result.h"
#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h"

namespace FlexFlow {
Expand Down Expand Up @@ -54,6 +55,16 @@
return runtime < other.runtime;
}

std::string format_as(GraphOptimizeState const &st) {

Check warning on line 58 in lib/compiler/src/compiler/graph_optimize_state.cc

View check run for this annotation

Codecov / codecov/patch

lib/compiler/src/compiler/graph_optimize_state.cc#L58

Added line #L58 was not covered by tests
return fmt::format("<GraphOptimizeState graph_optimize_result={} runtime={}>",
st.graph_optimize_result,
st.runtime);

Check warning on line 61 in lib/compiler/src/compiler/graph_optimize_state.cc

View check run for this annotation

Codecov / codecov/patch

lib/compiler/src/compiler/graph_optimize_state.cc#L60-L61

Added lines #L60 - L61 were not covered by tests
}

std::ostream &operator<<(std::ostream &s, GraphOptimizeState const &st) {
return (s << fmt::to_string(st));

Check warning on line 65 in lib/compiler/src/compiler/graph_optimize_state.cc

View check run for this annotation

Codecov / codecov/patch

lib/compiler/src/compiler/graph_optimize_state.cc#L64-L65

Added lines #L64 - L65 were not covered by tests
}

} // namespace FlexFlow

namespace std {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ std::string render_preprocessed_computation_graph_for_sp_decomposition(
ComputationGraph const &cg) {
std::unordered_set<layer_guid_t> weight_and_input_layers =
filter(get_layers(cg), [&](layer_guid_t const &l) {
ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs;
ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).op_attrs;
return op_attrs.has<WeightAttrs>() || op_attrs.has<InputAttrs>();
});

Expand All @@ -41,7 +41,7 @@ std::string render_preprocessed_computation_graph_for_sp_decomposition(
return "FAKE";
}
LayerAttrs a = cg.raw_graph.at(n);
RecordFormatter r = as_dot(a.attrs);
RecordFormatter r = as_dot(a.op_attrs);

if (a.name.has_value()) {
RecordFormatter rr;
Expand Down Expand Up @@ -75,7 +75,7 @@ std::optional<SeriesParallelDecomposition>
DiGraphView preprocessed_digraph = [&] {
std::unordered_set<layer_guid_t> weight_and_input_layers =
filter(get_layers(cg), [&](layer_guid_t const &l) {
ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs;
ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).op_attrs;
return op_attrs.has<WeightAttrs>() || op_attrs.has<InputAttrs>();
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,80 @@
#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h"
#include "op-attrs/pcg_operator_attrs.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.h"
#include "utils/containers/get_only.h"
#include "utils/graph/digraph/algorithms/materialize_digraph_view.h"
#include "utils/graph/instances/adjacency_digraph.h"
#include "utils/graph/series_parallel/get_series_parallel_decomposition.h"

namespace FlexFlow {

std::optional<SeriesParallelDecomposition>
get_pcg_series_parallel_decomposition(ParallelComputationGraph const &) {
NOT_IMPLEMENTED();
get_pcg_series_parallel_decomposition(ParallelComputationGraph const &pcg) {
{
DiGraphView unpreprocessed_digraph = pcg.raw_graph;
std::optional<SeriesParallelDecomposition> unpreprocessed_sp_decomposition =
get_series_parallel_decomposition(unpreprocessed_digraph);
if (unpreprocessed_sp_decomposition.has_value()) {
return unpreprocessed_sp_decomposition.value();
}
}

auto layer_is_weight_or_input = [&](parallel_layer_guid_t const &l) {
PCGOperatorAttrs op_attrs = get_parallel_layer_attrs(pcg, l).op_attrs;
return op_attrs.has<WeightAttrs>() || op_attrs.has<InputAttrs>();
};

auto layer_is_parallel_op = [&](parallel_layer_guid_t const &l) {
PCGOperatorAttrs op_attrs = get_parallel_layer_attrs(pcg, l).op_attrs;
return is_parallel_op(op_attrs);
};

std::function<parallel_layer_guid_t(parallel_layer_guid_t const &)>
follow_to_last_parallel_op =
[&](parallel_layer_guid_t const &starting_point)
-> parallel_layer_guid_t {
assert(layer_is_weight_or_input(starting_point) ||
layer_is_parallel_op(starting_point));

std::unordered_set<parallel_layer_guid_t> successors =
get_successors(pcg, starting_point);

if (successors.size() != 1) {
return starting_point;
}

parallel_layer_guid_t successor =
get_only(get_successors(pcg, starting_point));

assert(!layer_is_weight_or_input(successor));
if (layer_is_parallel_op(successor)) {
return follow_to_last_parallel_op(successor);
} else {
return starting_point;
}
};

DiGraphView preprocessed_digraph = [&] {
std::unordered_set<parallel_layer_guid_t> weight_and_input_layers =
filter(get_parallel_layers(pcg), layer_is_weight_or_input);

std::unordered_set<parallel_layer_guid_t> par_chain_endpoints =
transform(weight_and_input_layers, follow_to_last_parallel_op);

std::unordered_set<parallel_layer_guid_t> par_chain_endpoint_successors =
get_subgraph_successors(pcg, par_chain_endpoints);

DiGraph digraph = materialize_digraph_view<AdjacencyDiGraph>(pcg.raw_graph);
for (parallel_layer_guid_t const &src : par_chain_endpoints) {
for (parallel_layer_guid_t const &dst : par_chain_endpoint_successors) {
digraph.add_edge(DirectedEdge{src.raw_graph_node, dst.raw_graph_node});
}
}

return digraph;
}();

return get_series_parallel_decomposition(preprocessed_digraph);
}

} // namespace FlexFlow
15 changes: 15 additions & 0 deletions lib/compiler/test/src/compiler/graph_optimize_result.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include "compiler/graph_optimize_result.h"

namespace FlexFlow {

std::string format_as(GraphOptimizeResult const &r) {
return fmt::format("<GraphOptimizeResult\npcg={}\nmachine_mapping={}>",
as_dot(r.pcg),
r.machine_mapping);
}

std::ostream &operator<<(std::ostream &s, GraphOptimizeResult const &r) {
return (s << fmt::to_string(r));
}

} // namespace FlexFlow
Loading
Loading