[GPU] Adding GeGLU (openvinotoolkit#24970)
### Details:
- This PR extends the current `SwiGLU` primitive to support `GeGLU`, which uses a Gelu activation in place of Swish (see the scalar sketch below).
- This GeGLU pattern shows up in stable diffusion models.
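For reference, every variant here has the same shape: split the input into a gate half and a linear half along one axis, apply the activation to the gate, then multiply elementwise. A minimal scalar sketch of the three `GluType` branches (illustrative only, not the GPU kernel; it assumes the gate half comes first, i.e. `split_to_glu_idx == 0`):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

enum class GluType { Swish, Gelu, GeluTanh };

// Activation applied to the gate half; constants match the kernel changes below.
float glu_activation(float x, GluType type) {
    switch (type) {
    case GluType::Swish:    // x * sigmoid(x), beta == 1
        return x / (1.0f + std::exp(-x));
    case GluType::Gelu:     // erf-based GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5f * x * (1.0f + std::erf(x * 0.70710678f));
    case GluType::GeluTanh: // tanh approximation of GELU
        return 0.5f * x * (1.0f + std::tanh(0.79788456f * (x + 0.044715f * x * x * x)));
    }
    return x;
}

// out[i] = activation(gate[i]) * linear[i], where in = [gate | linear]
std::vector<float> glu_ref(const std::vector<float>& in, std::size_t split_length, GluType type) {
    std::vector<float> out(split_length);
    for (std::size_t i = 0; i < split_length; ++i)
        out[i] = glu_activation(in[i], type) * in[i + split_length];
    return out;
}
```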

### Tickets:
 - 143486
e-ddykim authored Jun 17, 2024
1 parent d25fe2a commit 44e68a2
Showing 11 changed files with 142 additions and 20 deletions.
18 changes: 17 additions & 1 deletion src/plugins/intel_gpu/include/intel_gpu/op/swiglu.hpp
@@ -11,21 +11,31 @@ namespace intel_gpu {
namespace op {

/// \brief Operator performing Swish Gated Linear Unit Activation
/// This operation performs gated linear unit activation that combines swish activation function
/// This operation performs gated linear unit activation that combines the swish or gelu activation function
class SwiGLU : public ov::op::Op {
public:
OPENVINO_OP("SwiGLU", "gpu_opset");

enum GluType {
Swish = 0,
Gelu,
Gelu_Tanh
};

SwiGLU() = default;
/// \brief Constructs a SwiGLU operation.
///
/// \param data Input tensor with data
/// \param axis The index of an axis in "data" along which to perform the split
/// \param split_lengths A list containing the sizes of each output tensor along the split "axis"
/// \param glu_type GLU type, one of Swish, Gelu and Gelu_Tanh
/// \param split_to_glu_idx Output index of variadic split, which is connected to GLU
/// \param output_type Output element type
SwiGLU(const Output<Node>& data,
int64_t axis,
int64_t split_lengths,
const GluType glu_type,
const size_t split_to_glu_idx,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;
@@ -36,13 +46,19 @@ class SwiGLU : public ov::op::Op {

int64_t get_axis() const { return m_axis; }
int64_t get_split_lengths() const { return m_split_lengths; }
GluType get_glu_type() const { return m_glu_type; }
size_t get_split_to_glu_idx() const { return m_split_to_glu_idx; }

void set_axis(int64_t axis) { m_axis = axis; }
void set_split_lengths(int64_t split_lengths) { m_split_lengths = split_lengths; }
void set_glu_type(GluType glu_type) { m_glu_type = glu_type; }
void set_split_to_glu_idx(size_t split_to_glu_idx) { m_split_to_glu_idx = split_to_glu_idx; }

private:
int64_t m_axis = 0;
int64_t m_split_lengths = 0;
GluType m_glu_type = GluType::Swish;
size_t m_split_to_glu_idx = 0;
ov::element::Type m_output_type;
};
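A hedged usage sketch of the extended constructor (the `Parameter` input and dimensions are invented for illustration):

```cpp
// Hypothetical example: erf-based Gelu applied to split output 0 of a
// last-axis split into two 1280-element halves.
auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f16,
                                                    ov::PartialShape{-1, -1, 2560});
auto geglu = std::make_shared<ov::intel_gpu::op::SwiGLU>(
    data,
    -1,                                        // axis: split along the last dimension
    1280,                                      // split_lengths: size of each half
    ov::intel_gpu::op::SwiGLU::GluType::Gelu,  // glu_type
    0,                                         // split_to_glu_idx: Gelu on output 0
    ov::element::undefined);                   // derive output type from the input
```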

18 changes: 16 additions & 2 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/swiglu.hpp
@@ -4,11 +4,12 @@

#pragma once
#include "primitive.hpp"
#include "intel_gpu/op/swiglu.hpp"

namespace cldnn {

/// @brief Swish Gated Linear Unit Activation primitive
/// @details Performs gated linear unit activation that combines swish activation function
/// @details Performs gated linear unit activation that combines the swish or gelu activation function
struct swiglu : public primitive_base<swiglu> {
CLDNN_DECLARE_PRIMITIVE(swiglu);

@@ -24,21 +25,29 @@ struct swiglu : public primitive_base<swiglu> {
const input_info& input,
const int64_t& axis,
const int64_t& split_lengths,
const ov::intel_gpu::op::SwiGLU::GluType glu_type,
const size_t split_to_glu_idx,
const tensor output_size,
const padding& output_padding = padding())
: primitive_base(id, {input}, {output_padding}),
axis(axis),
split_lengths(split_lengths),
glu_type(glu_type),
split_to_glu_idx(split_to_glu_idx),
output_size(output_size) {}

int64_t axis = 0;
int64_t split_lengths = 0;
ov::intel_gpu::op::SwiGLU::GluType glu_type = ov::intel_gpu::op::SwiGLU::GluType::Swish;
size_t split_to_glu_idx = 0;
tensor output_size;

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, axis);
seed = hash_combine(seed, split_lengths);
seed = hash_combine(seed, glu_type);
seed = hash_combine(seed, split_to_glu_idx);
return seed;
}

@@ -47,21 +56,26 @@ struct swiglu : public primitive_base<swiglu> {
return false;

auto rhs_casted = downcast<const swiglu>(rhs);
return axis == rhs_casted.axis && split_lengths == rhs_casted.split_lengths;
return axis == rhs_casted.axis && split_lengths == rhs_casted.split_lengths &&
glu_type == rhs_casted.glu_type && split_to_glu_idx == rhs_casted.split_to_glu_idx;
}

void save(BinaryOutputBuffer& ob) const override {
primitive_base<swiglu>::save(ob);
ob << axis;
ob << split_lengths;
ob << output_size;
ob << make_data(&glu_type, sizeof(glu_type));
ob << split_to_glu_idx;
}

void load(BinaryInputBuffer& ib) override {
primitive_base<swiglu>::load(ib);
ib >> axis;
ib >> split_lengths;
ib >> output_size;
ib >> make_data(&glu_type, sizeof(glu_type));
ib >> split_to_glu_idx;
}
};
} // namespace cldnn
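A note on the primitive changes above: the new `glu_type` and `split_to_glu_idx` fields participate in `hash()`, `operator==`, and the save/load serialization. Without that, two primitives differing only in activation type would hash and compare as equal, and the program cache could hand a compiled Swish kernel to a Gelu primitive.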
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/swiglu.cpp
@@ -39,6 +39,8 @@ struct swiglu_impl : typed_primitive_impl_ocl<swiglu> {
auto rank = impl_param.get_input_layout(0).get_partial_shape().rank();
params.axis = ov::util::normalize(primitive->axis, rank.get_length());
params.split_length = primitive->split_lengths;
params.glu_type = primitive->glu_type;
params.split_to_glu_idx = static_cast<int32_t>(primitive->split_to_glu_idx);

return params;
}
@@ -25,18 +25,34 @@ KERNEL(swiglu_gpu_ref)(

#if OUTPUT_DIMS == 5
const uint output_idx = OUTPUT_GET_INDEX(b, f, z, y, x);
const uint gate_idx = INPUT0_GET_INDEX(b, f, z, y, x);
const uint input_idx = INPUT0_GET_INDEX(b, f, z, y, x) + SPLIT_LENGTH;
#if SPLIT_TO_GLU_IDX == 0
const uint gate_idx = INPUT0_GET_INDEX(b, f, z, y, x);
const uint input_idx = gate_idx + SPLIT_LENGTH;
#else
const uint input_idx = INPUT0_GET_INDEX(b, f, z, y, x);
const uint gate_idx = input_idx + SPLIT_LENGTH;
#endif
#else // 2D spatial
const uint output_idx = OUTPUT_GET_INDEX(b, f, y, x);
const uint gate_idx = INPUT0_GET_INDEX(b, f, y, x);
const uint input_idx = INPUT0_GET_INDEX(b, f, y, x) + SPLIT_LENGTH;
#if SPLIT_TO_GLU_IDX == 0
const uint gate_idx = INPUT0_GET_INDEX(b, f, y, x);
const uint input_idx = gate_idx + SPLIT_LENGTH;
#else
const uint input_idx = INPUT0_GET_INDEX(b, f, y, x);
const uint gate_idx = input_idx + SPLIT_LENGTH;
#endif
#endif

ACCUMULATOR_TYPE res = ACCUMULATOR_VAL_ZERO;

res = (ACCUMULATOR_TYPE)input[gate_idx];
res /= ACCUMULATOR_VAL_ONE + exp(-(ACCUMULATOR_VAL_ONE * res));
#if GLU_TYPE == 0 // Swish
res /= ACCUMULATOR_VAL_ONE + exp(-(ACCUMULATOR_VAL_ONE * res));
#elif GLU_TYPE == 1 // Gelu
res = (GEGLU_HALF * res * (ACCUMULATOR_VAL_ONE + (erf(res * GEGLU_MULT))));
#elif GLU_TYPE == 2 // Gelu_Tanh
res = (GEGLU_HALF * res * (ACCUMULATOR_VAL_ONE + (tanh(GEGLU_SQUARE_2_OVER_PI * res * (ACCUMULATOR_VAL_ONE + GEGLU_MULT * res * res)))));
#endif
res *= (ACCUMULATOR_TYPE)input[input_idx];

output[output_idx] = TO_OUTPUT_TYPE(res);
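The `SPLIT_TO_GLU_IDX` branches above only choose which contiguous half of the fused input feeds the activation; the arithmetic is otherwise symmetric. The same selection in plain C++ (a sketch over flattened indices, not kernel code):

```cpp
#include <cstddef>
#include <utility>

// Returns {gate_idx, input_idx} for one element, mirroring the kernel branches.
// base_idx is the element's position within the first half along the split axis.
std::pair<std::size_t, std::size_t> select_glu_indices(std::size_t base_idx,
                                                       std::size_t split_length,
                                                       std::size_t split_to_glu_idx) {
    if (split_to_glu_idx == 0)
        return {base_idx, base_idx + split_length};  // gate first, linear second
    return {base_idx + split_length, base_idx};      // linear first, gate second
}
```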
@@ -30,6 +30,17 @@ JitConstants SwiGLUKernelRef::GetJitConstants(const swiglu_params& params) const

jit.AddConstants({MakeJitConstant("AXIS", params.axis)});
jit.AddConstants({MakeJitConstant("SPLIT_LENGTH", params.split_length)});
jit.AddConstants({MakeJitConstant("GLU_TYPE", params.glu_type)});
const std::string type_suffix = (GetAccumulatorType(params) == Datatype::F32) ? "f" : "h";
if (params.glu_type == ov::intel_gpu::op::SwiGLU::GluType::Gelu) {
jit.AddConstants({MakeJitConstant("GEGLU_HALF", "0.5" + type_suffix)});
jit.AddConstants({MakeJitConstant("GEGLU_MULT", "0.7071067811865475" + type_suffix)});
} else if (params.glu_type == ov::intel_gpu::op::SwiGLU::GluType::Gelu_Tanh) {
jit.AddConstants({MakeJitConstant("GEGLU_HALF", "0.5" + type_suffix)});
jit.AddConstants({MakeJitConstant("GEGLU_MULT", "0.044715" + type_suffix)});
jit.AddConstants({MakeJitConstant("GEGLU_SQUARE_2_OVER_PI", "0.79788458347320556640625" + type_suffix)});
}
jit.AddConstants({MakeJitConstant("SPLIT_TO_GLU_IDX", params.split_to_glu_idx)});
jit.Merge(MakeTypeJitConstants(GetAccumulatorType(params), "ACCUMULATOR"));
jit.Merge(GetTensorFriendlyWorkGroupsJit(params.outputs[0]));

@@ -5,15 +5,19 @@
#pragma once

#include "kernel_base_opencl.h"
#include "intel_gpu/op/swiglu.hpp"

namespace kernel_selector {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// swiglu_params
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct swiglu_params : public base_params {
swiglu_params() : base_params(KernelType::SWIGLU), axis(0), split_length(0) {}
swiglu_params() : base_params(KernelType::SWIGLU), axis(0), split_length(0),
glu_type(ov::intel_gpu::op::SwiGLU::GluType::Swish), split_to_glu_idx(0) {}
int32_t axis;
int32_t split_length;
ov::intel_gpu::op::SwiGLU::GluType glu_type;
int32_t split_to_glu_idx;
};

class SwiGLUKernelRef : public KernelBaseOpenCL {
4 changes: 4 additions & 0 deletions src/plugins/intel_gpu/src/plugin/ops/swiglu.cpp
@@ -28,6 +28,8 @@ static void CreateSwiGLUOp(ProgramBuilder& p, const std::shared_ptr<op::SwiGLU>&
inputs[0],
op->get_axis(),
op->get_split_lengths(),
op->get_glu_type(),
op->get_split_to_glu_idx(),
cldnn::tensor());
prim.output_data_types = get_output_data_types(op);
p.add_primitive(*op, prim);
@@ -36,6 +38,8 @@ static void CreateSwiGLUOp(ProgramBuilder& p, const std::shared_ptr<op::SwiGLU>&
inputs[0],
op->get_axis(),
op->get_split_lengths(),
op->get_glu_type(),
op->get_split_to_glu_idx(),
tensor_from_dims(op->get_output_shape(0)));
prim.output_data_types = get_output_data_types(op);
p.add_primitive(*op, prim);
@@ -15,8 +15,11 @@ namespace op {
SwiGLU::SwiGLU(const Output<Node>& data,
int64_t axis,
int64_t split_lengths,
const GluType glu_type,
const size_t split_to_glu_idx,
const ov::element::Type output_type)
: Op({data}), m_axis(axis), m_split_lengths(split_lengths), m_output_type(output_type) {
: Op({data}), m_axis(axis), m_split_lengths(split_lengths),
m_glu_type(glu_type), m_split_to_glu_idx(split_to_glu_idx), m_output_type(output_type) {
validate_and_infer_types();
}

@@ -44,6 +47,8 @@ std::shared_ptr<Node> SwiGLU::clone_with_new_inputs(const ov::OutputVector& new_
return std::make_shared<SwiGLU>(new_args.at(0),
m_axis,
m_split_lengths,
m_glu_type,
m_split_to_glu_idx,
m_output_type);
}

37 changes: 32 additions & 5 deletions src/plugins/intel_gpu/src/plugin/transformations/swiglu_fusion.cpp
@@ -10,7 +10,9 @@
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/swish.hpp"
#include "openvino/op/gelu.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

@@ -19,6 +21,7 @@ namespace intel_gpu {

SwiGLUFusion::SwiGLUFusion() {
using namespace ov::pass::pattern;
using ov::pass::pattern::op::Or;

auto last_dim_static = [](const ov::Output<ov::Node>& output) {
auto out_ps = output.get_node()->get_output_partial_shape(0);
@@ -37,24 +40,46 @@ SwiGLUFusion::SwiGLUFusion() {

// Swish(Xw) = Xw / (1.0 + exp(-beta * Xw))
auto swish_m = wrap_type<ov::op::v4::Swish>({variadic_split_m->output(0)});
auto gelu_m = wrap_type<ov::op::v7::Gelu>({variadic_split_m->output(0)});

// Mul(Xw, Xv) = Swish(Xw) * Xv
auto mul_m = wrap_type<ov::op::v1::Multiply>({swish_m, variadic_split_m->output(1)});
auto glu_m = std::make_shared<Or>(OutputVector{swish_m, gelu_m});
auto mul_m = wrap_type<ov::op::v1::Multiply>({glu_m, variadic_split_m->output(1)});

ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
OPENVINO_ASSERT(pattern_map.count(mul_m));
OPENVINO_ASSERT(pattern_map.count(swish_m));
OPENVINO_ASSERT(pattern_map.count(swish_m) || pattern_map.count(gelu_m));
OPENVINO_ASSERT(pattern_map.count(variadic_split_m));
OPENVINO_ASSERT(pattern_map.count(split_lengths_const_m));
OPENVINO_ASSERT(pattern_map.count(axis_const_m));
auto mul = std::dynamic_pointer_cast<ov::op::v1::Multiply>(pattern_map.at(mul_m).get_node_shared_ptr());
if (!mul || transformation_callback(mul))
return false;

size_t split_in_idx = ov::is_type<ov::op::v4::Swish>(mul->get_input_node_shared_ptr(0)) ? 1 : 0;
if (mul->input_value(split_in_idx).get_index() != 1)
return false;
auto isSwiGLU = pattern_map.count(swish_m);
auto isGeGLU = pattern_map.count(gelu_m);
size_t split_to_glu_idx = 0;
ov::intel_gpu::op::SwiGLU::GluType glu_type;

if (isSwiGLU) {
auto swish = std::dynamic_pointer_cast<ov::op::v4::Swish>(pattern_map.at(swish_m).get_node_shared_ptr());
glu_type = ov::intel_gpu::op::SwiGLU::GluType::Swish;
split_to_glu_idx = swish->input_value(0).get_index();

size_t split_in_idx = ov::is_type<ov::op::v4::Swish>(mul->get_input_node_shared_ptr(0)) ? 1 : 0;
if (mul->input_value(split_in_idx).get_index() == split_to_glu_idx)
return false;
} else if (isGeGLU) {
auto gelu = std::dynamic_pointer_cast<ov::op::v7::Gelu>(pattern_map.at(gelu_m).get_node_shared_ptr());
glu_type = (gelu->get_approximation_mode() == ov::op::GeluApproximationMode::ERF) ? ov::intel_gpu::op::SwiGLU::GluType::Gelu
: ov::intel_gpu::op::SwiGLU::GluType::Gelu_Tanh;
split_to_glu_idx = gelu->input_value(0).get_index();

size_t split_in_idx = ov::is_type<ov::op::v7::Gelu>(mul->get_input_node_shared_ptr(0)) ? 1 : 0;
if (mul->input_value(split_in_idx).get_index() == split_to_glu_idx)
return false;
}

auto variadic_split = std::dynamic_pointer_cast<ov::op::v1::VariadicSplit>(pattern_map.at(variadic_split_m).get_node_shared_ptr());
auto variadic_split_in_ps = variadic_split->get_input_partial_shape(0);
@@ -80,6 +105,8 @@ SwiGLUFusion::SwiGLUFusion() {
auto swiglu = std::make_shared<op::SwiGLU>(data,
axis_value,
split_lengths_value,
glu_type,
split_to_glu_idx,
output_type);
swiglu->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), swiglu);
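To make the matched pattern concrete, a hedged sketch of a GeGLU subgraph this pass should collapse (dimensions invented; split output 0 feeds Gelu, output 1 is the linear branch):

```cpp
auto data  = std::make_shared<ov::op::v0::Parameter>(ov::element::f16,
                                                     ov::PartialShape{-1, -1, 2560});
auto axis  = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1});
auto sizes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {1280, 1280});
auto split = std::make_shared<ov::op::v1::VariadicSplit>(data, axis, sizes);
auto gelu  = std::make_shared<ov::op::v7::Gelu>(split->output(0),
                                                ov::op::GeluApproximationMode::ERF);
auto mul   = std::make_shared<ov::op::v1::Multiply>(gelu, split->output(1));
// After SwiGLUFusion, this collapses into a single op::SwiGLU with
// glu_type == GluType::Gelu and split_to_glu_idx == 0.
```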
@@ -63,7 +63,7 @@ TEST(swiglu_gpu_test, swiglu_test_bfyx_dyn) {

topology topology;
topology.add(input_layout("input", input_layout_dynamic));
topology.add(swiglu("swiglu", input_info("input"), -1, 3, tensor()));
topology.add(swiglu("swiglu", input_info("input"), -1, 3, ov::intel_gpu::op::SwiGLU::GluType::Swish, 0, tensor()));

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));