diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 93e703dc03..86dfb74d06 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -38,11 +38,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             # Unary operators
             exir_ops.edge.aten.abs.default,
             exir_ops.edge.aten.clamp.default,
+            exir_ops.edge.aten.gelu.default,
             exir_ops.edge.aten.hardtanh.default,
             exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten.sigmoid.default,
-            exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.sqrt.default,
+            exir_ops.edge.aten.tanh.default,
             # Matrix multiplication operators
             exir_ops.edge.aten.bmm.default,
             exir_ops.edge.aten.mm.default,
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 00d8cbd3c5..fbb49f4799 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -196,6 +196,10 @@ class ComputeGraph final {
     }
   }
 
+  std::string extract_string(const ValueRef idx) {
+    return values_.at(idx).toString();
+  }
+
   inline std::vector<std::unique_ptr<PrepackNode>>& prepack_nodes() {
     return prepack_nodes_;
   }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
index 2d8ec36d9a..14e4e111a2 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
@@ -12,9 +12,11 @@ unary_op:
       OPERATOR: abs(X)
     - NAME: clamp
       OPERATOR: clamp(X, A, B)
+    - NAME: gelu
+      OPERATOR: 0.5 * X * (1 + tanh(sqrt(2 / 3.141593) * (X + 0.044715 * X * X * X)))
     - NAME: sigmoid
       OPERATOR: 1 / (1 + exp(-1 * X))
-    - NAME: tanh
-      OPERATOR: tanh(clamp(X, -15.0, 15.0))
     - NAME: sqrt
       OPERATOR: sqrt(X)
+    - NAME: tanh
+      OPERATOR: tanh(clamp(X, -15.0, 15.0))
diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
index b2fb1135d7..4dd615cda1 100644
--- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
@@ -100,10 +100,18 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
         kClampShaderName);                                                \
   }
 
+void gelu(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  // args[1] is the `approximate` string
+  // https://fburl.com/code/9omngmyo
+  // currently only `approximate = "tanh"` is supported
+  return add_unary_op_node(
+      graph, args[0], kDummyFloat, kDummyFloat, args[2], "gelu");
+}
+
 DEFINE_ACTIVATION_FN(abs);
 DEFINE_ACTIVATION_FN(sigmoid);
-DEFINE_ACTIVATION_FN(tanh);
 DEFINE_ACTIVATION_FN(sqrt);
+DEFINE_ACTIVATION_FN(tanh);
 DEFINE_CLAMP_FN(clamp);
 DEFINE_CLAMP_FN(hardtanh);
 DEFINE_RELU_FN(relu);
@@ -111,11 +119,12 @@ DEFINE_RELU_FN(relu);
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.abs.default, abs);
   VK_REGISTER_OP(aten.clamp.default, clamp);
+  VK_REGISTER_OP(aten.gelu.default, gelu);
   VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
   VK_REGISTER_OP(aten.relu.default, relu);
   VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
-  VK_REGISTER_OP(aten.tanh.default, tanh);
   VK_REGISTER_OP(aten.sqrt.default, sqrt);
+  VK_REGISTER_OP(aten.tanh.default, tanh);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 9f47284485..a1e6227a22 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -727,6 +727,18 @@ def get_native_batch_norm_inputs():
     return test_suite
 
 
+def get_gelu_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((M1), "tanh"),
+            ((M1, M2), "tanh"),
+            ((S1, M1, M2), "tanh"),
+            ((S1, S2, S2, M2), "tanh"),
+        ]
+    )
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -755,4 +767,5 @@ def get_native_batch_norm_inputs():
     "aten._softmax.default": get_softmax_inputs(),
     "aten._log_softmax.default": get_softmax_inputs(),
     "aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(),
+    "aten.gelu.default": get_gelu_inputs(),
 }
diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py
index fa02986a1c..c803f76792 100644
--- a/backends/vulkan/test/op_tests/utils/codegen.py
+++ b/backends/vulkan/test/op_tests/utils/codegen.py
@@ -24,6 +24,7 @@
     OPT_LAYOUT,
     OPT_MEMORY_FORMAT,
     OPT_SCALAR_TYPE,
+    STRING,
     TENSOR_VECTOR,
     TestSuite,
     TestSuiteGen,
@@ -351,6 +352,8 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             or ref.src_cpp_type == OPT_MEMORY_FORMAT
         ):
             ret_str += "add_none(); \n"
+        elif ref.src_cpp_type == STRING:
+            ret_str += f"add_string(std::string({ref.src_cpp_name})); \n"
         elif ref.src_cpp_type == TWO_TENSOR_TUPLE:
             ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n"
         elif ref.src_cpp_type == THREE_TENSOR_TUPLE:
diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py
index 6dac97583c..c1c6249e27 100644
--- a/backends/vulkan/test/op_tests/utils/codegen_base.py
+++ b/backends/vulkan/test/op_tests/utils/codegen_base.py
@@ -29,6 +29,7 @@
 OPT_LAYOUT = "::std::optional<at::Layout>"
 OPT_MEMORY_FORMAT = "::std::optional<at::MemoryFormat>"
 OPT_SCALAR_TYPE = "::std::optional<at::ScalarType>"
+STRING = "c10::string_view"
 TWO_TENSOR_TUPLE = "::std::tuple<at::Tensor, at::Tensor>"
 THREE_TENSOR_TUPLE = "::std::tuple<at::Tensor, at::Tensor, at::Tensor>"
 TENSOR_VECTOR = "::std::vector<at::Tensor>"
@@ -166,6 +167,8 @@ def create_input_data(self, arg: Argument, data: Any) -> str:  # noqa: C901
                 ret_str += "std::nullopt;"
             else:
                 ret_str += f"{str(data)};"
+        elif cpp_type == STRING:
+            ret_str += f'c10::string_view("{data}");'
         elif (
             cpp_type == OPT_SCALAR_TYPE
            or cpp_type == OPT_LAYOUT
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 531f1d28a9..2cd3bc3a27 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -1034,3 +1034,14 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_gelu(self):
+        class GeluModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.gelu = torch.nn.GELU(approximate="tanh")
+
+            def forward(self, x):
+                return self.gelu(x)
+
+        self.lower_unary_module_and_test_output(GeluModule())
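
A minimal standalone sketch, not part of the diff, that checks the tanh-based formula used by the new gelu shader variant against PyTorch's own tanh-approximated GELU. The helper name gelu_tanh_reference and the test shapes are made up for illustration; torch.nn.functional.gelu(..., approximate="tanh") is the same approximation the new operator and the GeluModule test target.

    # Sketch (assumed helper, not from the PR): compare the shader's formula
    #   0.5 * X * (1 + tanh(sqrt(2 / pi) * (X + 0.044715 * X^3)))
    # against PyTorch's tanh-approximated GELU on random inputs.
    import math
    import torch

    def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

    x = torch.randn(4, 8)
    expected = torch.nn.functional.gelu(x, approximate="tanh")
    assert torch.allclose(gelu_tanh_reference(x), expected, atol=1e-6)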