Skip to content

[ET-VK] Serialize list types from function args #2404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ class GraphBuilder {
ref_mapping_[fb_id] = ref;
}

template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
add_scalar_list_to_graph(const uint32_t fb_id, std::vector<T>&& value) {
ValueRef ref = compute_graph_->add_scalar_list(std::move(value));
ref_mapping_[fb_id] = ref;
}

void add_value_list_to_graph(
const uint32_t fb_id,
std::vector<ValueRef>&& value) {
ValueRef ref = compute_graph_->add_value_list(std::move(value));
ref_mapping_[fb_id] = ref;
}

void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) {
const auto fb_str = value->value_as_String()->string_val();
std::string string(fb_str->cbegin(), fb_str->cend());
Expand All @@ -150,6 +164,34 @@ class GraphBuilder {
case vkgraph::GraphTypes::VkTensor:
add_tensor_to_graph(fb_id, value->value_as_VkTensor());
break;
case vkgraph::GraphTypes::IntList:
add_scalar_list_to_graph(
fb_id,
std::vector<int64_t>(
value->value_as_IntList()->items()->cbegin(),
value->value_as_IntList()->items()->cend()));
break;
case vkgraph::GraphTypes::DoubleList:
add_scalar_list_to_graph(
fb_id,
std::vector<double>(
value->value_as_DoubleList()->items()->cbegin(),
value->value_as_DoubleList()->items()->cend()));
break;
case vkgraph::GraphTypes::BoolList:
add_scalar_list_to_graph(
fb_id,
std::vector<bool>(
value->value_as_BoolList()->items()->cbegin(),
value->value_as_BoolList()->items()->cend()));
break;
case vkgraph::GraphTypes::ValueList:
add_value_list_to_graph(
fb_id,
std::vector<ValueRef>(
value->value_as_ValueList()->items()->cbegin(),
value->value_as_ValueList()->items()->cend()));
break;
case vkgraph::GraphTypes::String:
add_string_to_graph(fb_id, value);
break;
Expand Down
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ ValueRef ComputeGraph::add_staging(
return idx;
}

ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(std::move(value));
return idx;
}

ValueRef ComputeGraph::add_string(std::string&& str) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(std::move(str));
Expand Down
14 changes: 8 additions & 6 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,13 @@ class ComputeGraph final {

template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
add_scalar_list(std::vector<T>&& values);
add_scalar(T value);

template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
add_scalar(T value);
add_scalar_list(std::vector<T>&& value);

ValueRef add_value_list(std::vector<ValueRef>&& value);

ValueRef add_string(std::string&& str);

Expand Down Expand Up @@ -212,17 +214,17 @@ class ComputeGraph final {

template <typename T>
inline typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
ComputeGraph::add_scalar_list(std::vector<T>&& values) {
ComputeGraph::add_scalar(T value) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(std::move(values));
values_.emplace_back(value);
return idx;
}

template <typename T>
inline typename std::enable_if<is_valid_scalar_type<T>::value, ValueRef>::type
ComputeGraph::add_scalar(T value) {
ComputeGraph::add_scalar_list(std::vector<T>&& value) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(value);
values_.emplace_back(std::move(value));
return idx;
}

Expand Down
54 changes: 46 additions & 8 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Union
from typing import cast, List, Optional, Union

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

Expand All @@ -15,8 +15,8 @@
from torch.export import ExportedProgram
from torch.fx import Node

_ScalarType = Union[int, bool, float]
_Argument = Union[Node, int, bool, float, str]
_ScalarType = Union[bool, int, float]
_Argument = Union[Node, List[Node], _ScalarType, List[_ScalarType], str]


class VkGraphBuilder:
Expand Down Expand Up @@ -150,14 +150,46 @@ def create_tensor_values(self, node: Node) -> int:
"Creating values for nodes with collection types is not supported yet."
)

def create_value_list_value(self, arg: List[Node]) -> int:
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.ValueList(
items=[self.get_or_create_value_for(e) for e in arg]
)
)
)
return len(self.values) - 1

def create_scalar_value(self, scalar: _ScalarType) -> int:
new_id = len(self.values)
if isinstance(scalar, int):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
if isinstance(scalar, float):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
if isinstance(scalar, bool):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar)))
elif isinstance(scalar, int):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar)))
elif isinstance(scalar, float):
self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar)))
return new_id

def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
new_id = len(self.values)
if isinstance(arg[0], bool):
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
)
)
elif isinstance(arg[0], int):
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
)
)
elif isinstance(arg[0], float):
self.values.append(
vk_graph_schema.VkValue(
vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
)
)
return new_id

def create_string_value(self, string: str) -> int:
Expand All @@ -174,8 +206,14 @@ def get_or_create_value_for(self, arg: _Argument):
return self.node_to_value_ids[arg]
# Return id for a newly created value
return self.create_tensor_values(arg)
elif isinstance(arg, (int, float, bool)):
elif isinstance(arg, list) and isinstance(arg[0], Node):
# pyre-ignore[6]
return self.create_value_list_value(arg)
elif isinstance(arg, _ScalarType):
return self.create_scalar_value(arg)
elif isinstance(arg, list) and isinstance(arg[0], _ScalarType):
# pyre-ignore[6]
return self.create_scalar_list_value(arg)
elif isinstance(arg, str):
return self.create_string_value(arg)
else:
Expand Down