Skip to content

Commit

Permalink
Update linear_combination_generic.h (NVIDIA#472)
Browse files Browse the repository at this point in the history
Add `skip_elementwise_` to support serial split-K in `linear_combination_generic.h`.
  • Loading branch information
hwu36 authored Jun 28, 2022
1 parent dae6b68 commit e45e773
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions include/cutlass/epilogue/thread/linear_combination_generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class LinearCombinationGeneric {

ElementCompute alpha_;
ElementCompute beta_;
bool skip_elementwise_;

public:

Expand All @@ -135,6 +136,7 @@ class LinearCombinationGeneric {

alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
skip_elementwise_ = false;
}

/// Returns true if source is needed
Expand All @@ -155,6 +157,10 @@ class LinearCombinationGeneric {
if (k_partition) {
beta_ = ElementCompute(1);
}

if (k_partition != k_partition_count - 1) {
skip_elementwise_ = true;
}
}

/// Computes linear scaling: D = alpha * accumulator + beta * source
Expand Down Expand Up @@ -188,7 +194,7 @@ class LinearCombinationGeneric {
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
}

intermediate = activation(intermediate);
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);

// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
Expand Down Expand Up @@ -219,7 +225,7 @@ class LinearCombinationGeneric {
intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum
}

intermediate = activation(intermediate);
intermediate = skip_elementwise_ ? intermediate : activation(intermediate);

// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
Expand Down

0 comments on commit e45e773

Please sign in to comment.