gelu (#3573)
Summary:
Pull Request resolved: #3573

- implemented the `gelu` op; only `approximate="tanh"` is supported for now
- added a string data type to the codegen
- added `extract_string` to `ComputeGraph.h`

Reviewed By: yipjustin, jorgep31415

Differential Revision: D57194536

fbshipit-source-id: 4c6c2e126fe35021248759ad4578d8f6aec9bffc
copyrightly authored and facebook-github-bot committed May 11, 2024
1 parent 43bfcd2 commit 60c94e8
Showing 8 changed files with 51 additions and 5 deletions.
backends/vulkan/partitioner/vulkan_partitioner.py (2 additions, 1 deletion)
@@ -38,11 +38,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             # Unary operators
             exir_ops.edge.aten.abs.default,
             exir_ops.edge.aten.clamp.default,
+            exir_ops.edge.aten.gelu.default,
             exir_ops.edge.aten.hardtanh.default,
             exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten.sigmoid.default,
-            exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.sqrt.default,
+            exir_ops.edge.aten.tanh.default,
             # Matrix multiplication operators
             exir_ops.edge.aten.bmm.default,
             exir_ops.edge.aten.mm.default,
backends/vulkan/runtime/graph/ComputeGraph.h (4 additions)
@@ -196,6 +196,10 @@ class ComputeGraph final {
     }
   }

+  std::string extract_string(const ValueRef idx) {
+    return values_.at(idx).toString();
+  }
+
   inline std::vector<std::unique_ptr<PrepackNode>>& prepack_nodes() {
     return prepack_nodes_;
   }
backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml (4 additions, 2 deletions)
@@ -12,9 +12,11 @@ unary_op:
     OPERATOR: abs(X)
   - NAME: clamp
     OPERATOR: clamp(X, A, B)
+  - NAME: gelu
+    OPERATOR: 0.5 * X * (1 + tanh(sqrt(2 / 3.141593) * (X + 0.044715 * X * X * X)))
   - NAME: sigmoid
     OPERATOR: 1 / (1 + exp(-1 * X))
-  - NAME: tanh
-    OPERATOR: tanh(clamp(X, -15.0, 15.0))
   - NAME: sqrt
     OPERATOR: sqrt(X)
+  - NAME: tanh
+    OPERATOR: tanh(clamp(X, -15.0, 15.0))
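
Note: the gelu OPERATOR string above is the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with pi written out as 3.141593. As a quick sanity check (not part of this commit), the same expression can be compared against PyTorch's own tanh-approximate GELU:

import math
import torch

def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # Same expression as the shader's OPERATOR string (tanh approximation of GELU).
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x * x * x)))

x = torch.randn(1024)
expected = torch.nn.functional.gelu(x, approximate="tanh")
assert torch.allclose(gelu_tanh_reference(x), expected, atol=1e-5)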
backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp (11 additions, 2 deletions)
@@ -100,22 +100,31 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
       kClampShaderName); \
 }

+void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  // args[1] is the `approximate` string
+  // https://fburl.com/code/9omngmyo
+  // currently only `approximate = "tanh"` is supported
+  return add_unary_op_node(
+      graph, args[0], kDummyFloat, kDummyFloat, args[2], "gelu");
+}
+
 DEFINE_ACTIVATION_FN(abs);
 DEFINE_ACTIVATION_FN(sigmoid);
-DEFINE_ACTIVATION_FN(tanh);
 DEFINE_ACTIVATION_FN(sqrt);
+DEFINE_ACTIVATION_FN(tanh);
 DEFINE_CLAMP_FN(clamp);
 DEFINE_CLAMP_FN(hardtanh);
 DEFINE_RELU_FN(relu);

 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.abs.default, abs);
   VK_REGISTER_OP(aten.clamp.default, clamp);
+  VK_REGISTER_OP(aten.gelu.default, gelu);
   VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
   VK_REGISTER_OP(aten.relu.default, relu);
   VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
-  VK_REGISTER_OP(aten.tanh.default, tanh);
   VK_REGISTER_OP(aten.sqrt.default, sqrt);
+  VK_REGISTER_OP(aten.tanh.default, tanh);
 }

 } // namespace vkcompute
backends/vulkan/test/op_tests/cases.py (13 additions)
@@ -727,6 +727,18 @@ def get_native_batch_norm_inputs():
     return test_suite


+def get_gelu_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((M1), "tanh"),
+            ((M1, M2), "tanh"),
+            ((S1, M1, M2), "tanh"),
+            ((S1, S2, S2, M2), "tanh"),
+        ]
+    )
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -755,4 +767,5 @@ def get_native_batch_norm_inputs():
"aten._softmax.default": get_softmax_inputs(),
"aten._log_softmax.default": get_softmax_inputs(),
"aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(),
"aten.gelu.default": get_gelu_inputs(),
}
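
Each test case in get_gelu_inputs pairs an input shape with the `approximate` argument; M1, M2, S1, S2 are size constants defined elsewhere in cases.py. At the ATen level, a case like ((M1, M2), "tanh") corresponds roughly to the following call (the sizes below are placeholders, not the actual constants):

import torch

x = torch.randn(64, 32)  # placeholder sizes standing in for (M1, M2)
y = torch.ops.aten.gelu(x, approximate="tanh")  # the overload registered as aten.gelu.default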
backends/vulkan/test/op_tests/utils/codegen.py (3 additions)
@@ -24,6 +24,7 @@
     OPT_LAYOUT,
     OPT_MEMORY_FORMAT,
     OPT_SCALAR_TYPE,
+    STRING,
     TENSOR_VECTOR,
     TestSuite,
     TestSuiteGen,
@@ -351,6 +352,8 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
             or ref.src_cpp_type == OPT_MEMORY_FORMAT
         ):
             ret_str += "add_none(); \n"
+        elif ref.src_cpp_type == STRING:
+            ret_str += f"add_string(std::string({ref.src_cpp_name})); \n"
         elif ref.src_cpp_type == TWO_TENSOR_TUPLE:
             ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n"
         elif ref.src_cpp_type == THREE_TENSOR_TUPLE:
backends/vulkan/test/op_tests/utils/codegen_base.py (3 additions)
@@ -29,6 +29,7 @@
 OPT_LAYOUT = "::std::optional<at::Layout>"
 OPT_MEMORY_FORMAT = "::std::optional<at::MemoryFormat>"
 OPT_SCALAR_TYPE = "::std::optional<at::ScalarType>"
+STRING = "c10::string_view"
 TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor>"
 THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor,at::Tensor,at::Tensor>"
 TENSOR_VECTOR = "::std::vector<at::Tensor>"
@@ -166,6 +167,8 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901
                 ret_str += "std::nullopt;"
             else:
                 ret_str += f"{str(data)};"
+        elif cpp_type == STRING:
+            ret_str += f'c10::string_view("{data}");'
         elif (
             cpp_type == OPT_SCALAR_TYPE
             or cpp_type == OPT_LAYOUT
backends/vulkan/test/test_vulkan_delegate.py (11 additions)
@@ -1034,3 +1034,14 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_gelu(self):
+        class GeluModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gelu = torch.nn.GELU(approximate="tanh")
+
+            def forward(self, x):
+                return self.gelu(x)
+
+        self.lower_unary_module_and_test_output(GeluModule())
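
The module under test pins approximate="tanh", the only variant this commit supports; the default erf-based GELU (approximate="none") is not covered. A small eager-mode sketch (outside this test suite) of how the two variants differ:

import torch

x = torch.randn(1024)
exact = torch.nn.GELU()(x)                     # erf-based, approximate="none" (not supported by this commit)
approx = torch.nn.GELU(approximate="tanh")(x)  # the variant exercised by test_vulkan_backend_gelu
print((exact - approx).abs().max())            # small but nonzero difference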
