Commit 4fe85ed

Updating unity substitution set to work with new PCG interface
Victor Li committed Feb 24, 2025
1 parent 1394173 commit 4fe85ed
Showing 6 changed files with 1,064 additions and 870 deletions.
3 changes: 3 additions & 0 deletions .envrc
@@ -0,0 +1,3 @@
source_up_if_exists

use flake
8 changes: 8 additions & 0 deletions .vimrc
@@ -0,0 +1,8 @@
" example search path configuration
set path=lib/runtime/**,lib/**

" set build target
" let g:target = "pcg"

" set test target
" let g:test_target = "utils-test"
@@ -125,7 +125,11 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
}};
case OperatorType::REDUCTION:
return PCGOperatorAttrs{ReductionAttrs{
acc.get<nonnegative_int>(OperatorAttributeKey::PARALLEL_DEGREE),
}};
case OperatorType::SOFTMAX:
return PCGOperatorAttrs{SoftmaxAttrs{
acc.get<ff_dim_t>(OperatorAttributeKey::AXIS),
}};
case OperatorType::BATCHMATMUL:
case OperatorType::SCALAR_MULTIPLY:
@@ -138,7 +142,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
case OperatorType::TANH:
case OperatorType::ELU:
case OperatorType::FLAT:
case OperatorType::SOFTMAX:
case OperatorType::BATCHNORM:
case OperatorType::CONCAT:
case OperatorType::SPLIT:
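
For context, here is a minimal standalone sketch of the dispatch pattern this hunk extends, with SOFTMAX now materialized explicitly instead of falling through to the unsupported cases. The types below are simplified stand-ins, not FlexFlow's real ReductionAttrs/SoftmaxAttrs definitions:

#include <map>
#include <stdexcept>
#include <variant>

enum class OperatorType { REDUCTION, SOFTMAX, FLAT };
enum class OperatorAttributeKey { PARALLEL_DEGREE, AXIS };

struct ReductionAttrs { int repartition_degree; };
struct SoftmaxAttrs { int dim; };
using PCGOperatorAttrs = std::variant<ReductionAttrs, SoftmaxAttrs>;

// Materialize a typed attrs struct from a generic attribute map,
// mirroring the switch in materialize_operator_from_attrs_map.
PCGOperatorAttrs materialize(OperatorType op_type,
                             std::map<OperatorAttributeKey, int> const &acc) {
  switch (op_type) {
    case OperatorType::REDUCTION:
      return ReductionAttrs{acc.at(OperatorAttributeKey::PARALLEL_DEGREE)};
    case OperatorType::SOFTMAX:
      // Newly handled: softmax carries the axis it normalizes over.
      return SoftmaxAttrs{acc.at(OperatorAttributeKey::AXIS)};
    default:
      throw std::runtime_error("unhandled operator type");
  }
}

int main() {
  std::map<OperatorAttributeKey, int> attrs = {{OperatorAttributeKey::AXIS, 3}};
  return std::holds_alternative<SoftmaxAttrs>(
             materialize(OperatorType::SOFTMAX, attrs))
             ? 0
             : 1;
}
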
217 changes: 173 additions & 44 deletions lib/substitutions/src/substitutions/unity_substitution_set.cc
@@ -13,18 +13,42 @@ namespace FlexFlow {
std::vector<Substitution>
get_substitution_set(MachineSpecification const &resources) {
std::vector<Substitution> substitutions;
for (nonnegative_int num_dims :
for (nonnegative_int dim :
nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) {
for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources);
degree *= 2_n) {
substitutions.push_back(
create_replicate_linear_combine(num_dims, degree, true));
create_replicate_linear_combine(dim, degree, true));
substitutions.push_back(
create_replicate_linear_combine(num_dims, degree, false));
create_replicate_linear_combine(dim, degree, false));
substitutions.push_back(
create_partition_linear_combine(num_dims, degree, true));
create_partition_linear_combine(dim, degree, true));
substitutions.push_back(
create_partition_linear_combine(num_dims, degree, false));
create_partition_linear_combine(dim, degree, false));
substitutions.push_back(
create_partition_relu_combine(ff_dim_t{dim}, degree));
substitutions.push_back(
create_partition_add_combine(ff_dim_t{dim}, degree));
substitutions.push_back(create_partition_attention_combine(dim, degree));
substitutions.push_back(create_replicate_attention_reduce(dim, degree));
}
}
for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources);
degree *= 2_n) {
substitutions.push_back(create_partition_conv2d_combine(4_n, degree));
}

for (nonnegative_int partition_dim :
nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) {
for (nonnegative_int softmax_dim :
nonnegative_range(1_n, nonnegative_int{MAX_TENSOR_DIM})) {
for (nonnegative_int degree = 1_n; degree <= get_num_gpus(resources);
degree *= 2_n) {
if (partition_dim != softmax_dim) {
substitutions.push_back(create_partition_softmax_combine(
ff_dim_t{partition_dim}, ff_dim_t{softmax_dim}, degree));
}
}
}
}
substitutions.push_back(create_fuse_linear_activation(Activation::RELU));
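
The hunk above sweeps parallel degrees over powers of two up to the GPU count, generates the linear/relu/add/attention rewrites once per (dim, degree) pair, and emits softmax rewrites only when the partition axis differs from the softmax axis. A standalone sketch of just that enumeration logic, with plain ints standing in for nonnegative_int/ff_dim_t and an assumed GPU count:

#include <cstdio>

int main() {
  int const num_gpus = 8;        // stand-in for get_num_gpus(resources)
  int const max_tensor_dim = 4;  // stand-in for MAX_TENSOR_DIM

  // One substitution per (dim, degree) pair for each rewrite family.
  for (int dim = 1; dim < max_tensor_dim; ++dim) {
    for (int degree = 1; degree <= num_gpus; degree *= 2) {
      std::printf("linear/relu/add/attention rewrites: dim=%d degree=%d\n",
                  dim, degree);
    }
  }

  // Softmax rewrites skip the degenerate case of partitioning the softmax axis.
  for (int partition_dim = 1; partition_dim < max_tensor_dim; ++partition_dim) {
    for (int softmax_dim = 1; softmax_dim < max_tensor_dim; ++softmax_dim) {
      for (int degree = 1; degree <= num_gpus; degree *= 2) {
        if (partition_dim != softmax_dim) {
          std::printf("partition_softmax_combine: partition=%d softmax=%d "
                      "degree=%d\n",
                      partition_dim, softmax_dim, degree);
        }
      }
    }
  }
}
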
@@ -173,7 +197,7 @@ Substitution create_partition_linear_combine(nonnegative_int num_dims,
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree}),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
OperatorAttributeValue{ff_dim_t{1_n}}),
OperatorAttributeValue{ff_dim_t{0_n}}),
}};
OutputGraphExprValue o_partition_input_output =
get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));
@@ -265,13 +289,13 @@ Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree}),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
OperatorAttributeValue{ff_dim_t{1_n}}),
OperatorAttributeValue{ff_dim_t{0_n}}),
}};

OutputGraphExprValue o_partition_input_output =
get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));

/*OutputOperatorAttrsAssignment replicate_weights_expr =
OutputOperatorAttrsAssignment replicate_weights_expr =
OutputOperatorAttrsAssignment{
std::nullopt,
{
Expand All @@ -283,10 +307,7 @@ Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
b.add_output_graph_node(replicate_weights_expr, {o_weight}, 1_n));

std::vector<OutputGraphExprValue> o_conv2d_inputs = {
o_partition_input_output, o_replicate_weights_output};*/

std::vector<OutputGraphExprValue> o_conv2d_inputs = {o_partition_input_output,
o_weight};
o_partition_input_output, o_replicate_weights_output};

OutputOperatorAttrsAssignment conv2d_expr = OutputOperatorAttrsAssignment{
b.pattern_node_named("conv2d"),
@@ -321,15 +342,16 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,

SubstitutionBuilder b;

auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all());
auto [p_query_weight, o_query_weight] =
auto [p_query_input, o_query_input] =
b.add_input(tensor_attribute_pattern_match_all());
auto [p_key_input, o_key_input] =
b.add_input(tensor_attribute_pattern_match_all());
auto [p_key_weight, o_key_weight] =
auto [p_value_input, o_value_input] =
b.add_input(tensor_attribute_pattern_match_all());
auto [p_value_weight, o_value_weight] =
auto [p_weights, o_weights] =
b.add_input(tensor_attribute_pattern_match_all());
std::vector<PatternValue> p_inputs = {
p_input, p_input, p_input, p_query_weight, p_key_weight, p_value_weight};
p_query_input, p_key_input, p_value_input, p_weights};

OperatorAttributePattern attention_pattern = OperatorAttributePattern{{
op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION),
@@ -351,19 +373,35 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree}),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
OperatorAttributeValue{ff_dim_t{1_n}}),
OperatorAttributeValue{ff_dim_t{0_n}}),
}};

OutputGraphExprValue o_partition_input_output =
get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));
OutputGraphExprValue o_partition_query_input_output = get_only(
b.add_output_graph_node(partition_input_expr, {o_query_input}, 1_n));

OutputGraphExprValue o_partition_key_input_output = get_only(
b.add_output_graph_node(partition_input_expr, {o_key_input}, 1_n));

OutputGraphExprValue o_partition_value_input_output = get_only(
b.add_output_graph_node(partition_input_expr, {o_value_input}, 1_n));

OutputOperatorAttrsAssignment replicate_weight_expr =
OutputOperatorAttrsAssignment{
std::nullopt,
{
set_op_type_attr(OperatorType::REPLICATE),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree}),
}};

OutputGraphExprValue o_replicate_weight_output = get_only(
b.add_output_graph_node(replicate_weight_expr, {o_weights}, 1_n));

std::vector<OutputGraphExprValue> o_attention_inputs = {
o_partition_input_output,
o_partition_input_output,
o_partition_input_output,
o_query_weight,
o_key_weight,
o_value_weight};
o_partition_query_input_output,
o_partition_key_input_output,
o_partition_value_input_output,
o_replicate_weight_output};

OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{
b.pattern_node_named("attention"),
@@ -394,17 +432,19 @@ Substitution create_partition_attention_combine(nonnegative_int num_heads,

Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
nonnegative_int degree) {

SubstitutionBuilder b;

auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all());
auto [p_query_weight, o_query_weight] =
auto [p_query_input, o_query_input] =
b.add_input(tensor_attribute_pattern_match_all());
auto [p_key_weight, o_key_weight] =
auto [p_key_input, o_key_input] =
b.add_input(tensor_attribute_pattern_match_all());
auto [p_value_weight, o_value_weight] =
auto [p_value_input, o_value_input] =
b.add_input(tensor_attribute_pattern_match_all());
auto [p_weights, o_weights] =
b.add_input(tensor_attribute_pattern_match_all());
std::vector<PatternValue> p_inputs = {
p_input, p_input, p_input, p_query_weight, p_key_weight, p_value_weight};
p_query_input, p_key_input, p_value_input, p_weights};

OperatorAttributePattern attention_pattern = OperatorAttributePattern{{
op_type_equals_constraint(OperatorType::MULTIHEAD_ATTENTION),
@@ -421,20 +461,40 @@ Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
OutputOperatorAttrsAssignment replicate_input_expr =
OutputOperatorAttrsAssignment{
std::nullopt,
{set_op_type_attr(OperatorType::REPLICATE),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree})}};
{
set_op_type_attr(OperatorType::REPLICATE),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree}),
}};

OutputGraphExprValue o_replicate_input_output =
get_only(b.add_output_graph_node(replicate_input_expr, {o_input}, 1_n));
OutputGraphExprValue o_replicate_query_input_output = get_only(
b.add_output_graph_node(replicate_input_expr, {o_query_input}, 1_n));

OutputGraphExprValue o_replicate_key_input_output = get_only(
b.add_output_graph_node(replicate_input_expr, {o_key_input}, 1_n));

OutputGraphExprValue o_replicate_value_input_output = get_only(
b.add_output_graph_node(replicate_input_expr, {o_value_input}, 1_n));

OutputOperatorAttrsAssignment partition_weight_expr =
OutputOperatorAttrsAssignment{
std::nullopt,
{
set_op_type_attr(OperatorType::REPARTITION),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree}),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
OperatorAttributeValue{ff_dim_t{1_n}}),
}};

OutputGraphExprValue o_partition_weight_output = get_only(
b.add_output_graph_node(partition_weight_expr, {o_weights}, 1_n));

std::vector<OutputGraphExprValue> o_attention_inputs = {
o_replicate_input_output,
o_replicate_input_output,
o_replicate_input_output,
o_query_weight,
o_key_weight,
o_value_weight};
o_replicate_query_input_output,
o_replicate_key_input_output,
o_replicate_value_input_output,
o_partition_weight_output};

OutputOperatorAttrsAssignment attention_expr = OutputOperatorAttrsAssignment{
b.pattern_node_named("attention"),
@@ -451,14 +511,83 @@ Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
OperatorAttributeValue{degree}),
},
};
OutputGraphExprValue o_reduce_output = get_only(
b.add_output_graph_node(reduce_expr, {o_attention_output}, 1_n));
OutputGraphExprValue o_reduce_output =
get_only(b.add_output_graph_node(reduce_expr, {o_attention_output}, 1_n));

b.equate_outputs(p_attention_output, o_reduce_output);

return b.get_substitution();
}
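
A toy illustration of why this head-parallel rewrite ends in a REDUCE rather than a COMBINE: the attention output projection sums over heads, so when the packed weights are partitioned by head, each device produces a partial output and the full result is their elementwise sum. Scalar stand-ins below, not FlexFlow code:

#include <cstdio>

int main() {
  // Two heads, one scalar per head; full output = h0*w0 + h1*w1.
  double h[2] = {0.3, 0.7};   // per-head attention results
  double w[2] = {2.0, -1.0};  // per-head slices of the output projection

  double full = h[0] * w[0] + h[1] * w[1];

  // Head-parallel with degree 2: device i holds only head i's weights,
  // computes a partial, and a REDUCE (sum) recovers the full output.
  double partial0 = h[0] * w[0];
  double partial1 = h[1] * w[1];
  double reduced = partial0 + partial1;

  std::printf("full=%f reduced=%f\n", full, reduced);  // identical
}
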

Substitution create_partition_softmax_combine(ff_dim_t softmax_dim,
ff_dim_t partition_dim,
nonnegative_int degree) {
if (partition_dim == softmax_dim) {
throw mk_runtime_error(
fmt::format("partition dim {} must not be equal to softmax dim {}",
partition_dim,
softmax_dim));
}
SubstitutionBuilder b;

auto [p_input, o_input] = b.add_input(tensor_attribute_pattern_match_all());
std::vector<PatternValue> p_inputs = {p_input};

OperatorAttributePattern softmax_pattern = OperatorAttributePattern{{
op_type_equals_constraint(OperatorType::SOFTMAX),
op_attr_key_divisible_by(OperatorAttributeKey::OUT_CHANNELS, degree),
op_attr_key_divisible_by(OperatorAttributeKey::SOFTMAX_DIM,
softmax_dim.value),
}};

PatternValue p_softmax_output =
get_only(b.add_pattern_node(softmax_pattern,
p_inputs,
{tensor_attribute_pattern_match_all()},
"softmax"));

OutputOperatorAttrsAssignment partition_input_expr =
OutputOperatorAttrsAssignment{
std::nullopt,
{
set_op_type_attr(OperatorType::REPARTITION),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree}),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
OperatorAttributeValue{partition_dim}),
}};

OutputGraphExprValue o_partition_input_output =
get_only(b.add_output_graph_node(partition_input_expr, {o_input}, 1_n));

std::vector<OutputGraphExprValue> o_softmax_inputs = {
o_partition_input_output};

OutputOperatorAttrsAssignment softmax_expr = OutputOperatorAttrsAssignment{
b.pattern_node_named("softmax"),
{},
};
OutputGraphExprValue o_softmax_output =
get_only(b.add_output_graph_node(softmax_expr, o_softmax_inputs, 1_n));

OutputOperatorAttrsAssignment combine_expr = OutputOperatorAttrsAssignment{
std::nullopt,
{
set_op_type_attr(OperatorType::COMBINE),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DEGREE,
OperatorAttributeValue{degree}),
set_attr_to_constant(OperatorAttributeKey::PARALLEL_DIM,
OperatorAttributeValue{partition_dim}),
},
};
OutputGraphExprValue o_combine_output =
get_only(b.add_output_graph_node(combine_expr, {o_softmax_output}, 1_n));

b.equate_outputs(p_softmax_output, o_combine_output);

return b.get_substitution();
}
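
A small standalone demonstration of why the function above rejects partition_dim == softmax_dim: softmax normalizes over softmax_dim, so sharding that axis leaves each device with only a partial sum of exponentials, while sharding any other axis keeps each normalization on a single device. Plain C++, independent of FlexFlow:

#include <cmath>
#include <cstdio>
#include <vector>

// softmax over a single axis, given as a flat vector
std::vector<double> softmax(std::vector<double> const &row) {
  double sum = 0;
  for (double v : row) {
    sum += std::exp(v);
  }
  std::vector<double> out;
  for (double v : row) {
    out.push_back(std::exp(v) / sum);
  }
  return out;
}

int main() {
  std::vector<double> row = {1, 2, 3, 4};

  // True softmax over the full axis.
  std::vector<double> full = softmax(row);

  // "Partitioning" the softmax axis into two shards and normalizing each
  // shard locally gives different numbers: each shard only sees part of
  // the sum of exponentials.
  std::vector<double> shard0 = softmax({1, 2});
  std::vector<double> shard1 = softmax({3, 4});

  std::printf("full[0]=%f shard0[0]=%f\n", full[0], shard0[0]);
  std::printf("full[2]=%f shard1[0]=%f\n", full[2], shard1[0]);
}
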

Substitution create_partition_add_combine(ff_dim_t parallel_dim,
nonnegative_int degree) {
SubstitutionBuilder b;