[ET-VK] Introduce memory metadata tagging pass #6669

Merged · 3 commits · Nov 5, 2024
30 changes: 23 additions & 7 deletions backends/vulkan/_passes/TARGETS
@@ -16,6 +16,20 @@ runtime.python_library(
],
)

runtime.python_library(
name = "int4_weight_only_quantizer",
srcs = [
"int4_weight_only_quantizer.py",
],
visibility = [
"//executorch/backends/...",
],
deps = [
"//executorch/backends/vulkan:custom_ops_lib",
"//pytorch/ao:torchao",
]
)

runtime.python_library(
name = "remove_local_scalar_dense",
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -30,17 +44,18 @@ runtime.python_library(
)

runtime.python_library(
- name = "int4_weight_only_quantizer",
- srcs = [
- "int4_weight_only_quantizer.py",
- ],
+ name = "tag_memory_meta_pass",
+ srcs = ["tag_memory_meta_pass.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
- "//executorch/backends/vulkan:custom_ops_lib",
- "//pytorch/ao:torchao",
- ]
+ "//caffe2:torch",
+ "//executorch/exir:pass_base",
+ "//executorch/exir/dialects:lib",
+ "//executorch/backends/vulkan:utils_lib",
+ "//executorch/backends/vulkan/serialization:lib",
+ ],
)

runtime.python_library(
@@ -56,5 +71,6 @@ runtime.python_library(
":insert_prepack_nodes",
":int4_weight_only_quantizer",
":remove_local_scalar_dense",
":tag_memory_meta_pass"
]
)
2 changes: 2 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -5,9 +5,11 @@
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass

__all__ = [
"insert_prepack_nodes",
"VkInt4WeightOnlyQuantizer",
"RemoveLocalScalarDenseOpsTransform",
"TagMemoryMetaPass",
]
236 changes: 236 additions & 0 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from copy import deepcopy
from typing import Set

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.op_registry import get_op_features, has_impl

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
VkMemoryLayout,
VkStorageType,
)

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult

from torch._subclasses.fake_tensor import FakeTensor

from torch.fx.passes.tools_common import NodeList
from torch.fx.passes.utils.fuser_utils import topo_sort

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)


def set_memory_metadata(
node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout
) -> None:
utils.set_node_spec_attr(node, "vk_storage_type", storage)
utils.set_node_spec_attr(node, "vk_memory_layout", layout)


class TagMemoryMetaPass(ExportPass):
"""
There are a variety of ways that tensors can be represented in Vulkan. The two main
descriptors for how a tensor is laid out in memory are:

1. Storage Type (buffer or texture)
2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.)

Due to the differences between buffers and textures, and the differences between
different memory layouts, an implementation for an operator may only support a
specific set of (storage type, memory layout) combinations.

Furthermore, if an operator implementation supports multiple (storage type, memory
layout) combinations, there may be a "preferred" setting which results in optimal
performance.

This pass is responsible for ensuring that all tensors participating in an operator
call have a valid/optimal (storage type, memory layout) setting, and for inserting
transition operators to transfer input tensors to the correct memory settings when
necessary.
"""

def __init__(
self,
texture_limits: utils.ImageExtents,
default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
):
super().__init__()
self.default_storage: VkStorageType = default_storage_type
self.default_layout: VkMemoryLayout = default_memory_layout
self.texture_limits = texture_limits

def propose_node_storage(
self,
node: torch.fx.Node,
) -> VkStorageType:
"""
Uses the operator registry to determine the storage type that should be used for
a given node. The storage type is determined with the following priorities:
1. In some cases, a tensor involved in the computation may be too large to be
represented as a texture. If this is the case, the node is "opinionated" and
buffer representation must be used.
2. If the operator called by the node indicates an optimal storage type, or only
supports a single storage type, use that storage type. If either is true,
then the node is considered to be opinionated as well. If multiple storage
types are supported and no preferred storage type is indicated, then the node
is not opinionated; go to the next step.
3. If the node's arguments already have memory metadata annotations, then
preserve the settings of the first argument. Otherwise, proceed to the next
step.
4. Recursively search the node's users to see if any subsequent users are
opinionated; inherit the settings of the first opinionated node. If no
opinionated user can be found, then proceed to the last step.
5. Use the default storage type setting.
"""
# The node may have an input/output tensor that is too big to be stored in a
# texture. In this case, buffer storage must be used. Note that the partitioner
# has already checked for the fact that buffer storage is supported by the
# operator.
if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0:
return VkStorageType.BUFFER

valid_storage_types: Set[VkStorageType] = utils.all_storage_types

# pyre-ignore
if has_impl(node.target):
# pyre-ignore
features = get_op_features(node.target)
valid_storage_types = features.supported_storage_types()
storage = features.propose_storage_type()
if storage is not None:
return storage

for arg in node.args:
if isinstance(arg, torch.fx.Node) and isinstance(
arg.meta["val"], FakeTensor
):
storage = utils.get_node_storage_type(arg)
if storage is not None and storage in valid_storage_types:
return storage

# If no storage type has been resolved yet, assume the optimal storage type of
# the first opinionated user. This search is recursive.
for user in node.users:
optimal_storage = self.propose_node_storage(user)
if optimal_storage is not None:
return optimal_storage

if self.default_storage in valid_storage_types:
return self.default_storage
else:
return next(iter(valid_storage_types))

def propose_node_layout(
self,
node: torch.fx.Node,
storage: VkStorageType,
) -> VkMemoryLayout:
"""
Performs the same steps as propose_node_storage, but detects the memory layout
that should be used for the specific storage type. The same prioritization logic
is applied.
"""
valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
# pyre-ignore
if has_impl(node.target):
# pyre-ignore
features = get_op_features(node.target)
valid_layouts = features.supported_memory_layouts(storage)
layout = features.propose_memory_layout(storage)
if layout is not None:
return layout

for arg in node.args:
if isinstance(arg, torch.fx.Node) and isinstance(
arg.meta["val"], FakeTensor
):
layout = utils.get_node_memory_layout(arg)
if layout is not None and layout in valid_layouts:
return layout

# If no memory layout has been resolved yet, assume the optimal layout of the
# first opinionated user. This search is recursive.
for user in node.users:
optimal_layout = self.propose_node_layout(user, storage)
if optimal_layout is not None:
return optimal_layout

# As a last resort, return the default memory layout.
if self.default_layout in valid_layouts:
return self.default_layout
else:
return next(iter(valid_layouts))

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

for node in sorted_nodes:
if not isinstance(node.meta["val"], FakeTensor):
continue

if node.target == exir_ops.edge.et_vk.prepack.default:
continue

storage = self.propose_node_storage(node)
layout = self.propose_node_layout(node, storage)

set_memory_metadata(node, storage, layout)

inserting_transitions_for_node = False
for i, arg in enumerate(node.args):
if not isinstance(arg, torch.fx.Node):
continue
if not isinstance(arg.meta["val"], FakeTensor):
continue

arg_storage = utils.get_node_storage_type(arg)
arg_layout = utils.get_node_memory_layout(arg)

if arg_storage is None:
utils.set_node_spec_attr(arg, "vk_storage_type", storage)
arg_storage = storage
if arg_layout is None:
utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
arg_layout = layout

if arg_storage == storage and arg_layout == layout:
continue

if not inserting_transitions_for_node:
inserting_transitions_for_node = True
logger.info(
f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
)

logger.info(
f" arg {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
)

# Insert a clone node to copy the original tensor to a tensor with the
# desired storage type and memory layout.
with graph_module.graph.inserting_before(node):
clone_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten.clone.default,
(arg,),
)
clone_node.meta["val"] = arg.meta["val"]
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
clone_node.meta["spec"].const = False
set_memory_metadata(clone_node, storage, layout)
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)

return PassResult(graph_module, True)
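
As context for reviewers, a minimal sketch of how the new pass might be driven (not part of this diff; the texture limit values and the direct ExportPass call are assumptions based on the code above):

from executorch.backends.vulkan._passes import TagMemoryMetaPass
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

# Hypothetical texture limits (width, height, depth); real values would be
# queried from the target Vulkan device.
texture_limits = (16384, 16384, 2048)

tag_pass = TagMemoryMetaPass(
    texture_limits,
    default_storage_type=VkStorageType.TEXTURE_3D,
    default_memory_layout=VkMemoryLayout.TENSOR_WIDTH_PACKED,
)

# graph_module is assumed to be an edge-dialect torch.fx.GraphModule produced
# by the export pipeline; ExportPass instances are callable and return a
# PassResult wrapping the transformed module.
result = tag_pass(graph_module)
graph_module = result.graph_module
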
8 changes: 5 additions & 3 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -94,9 +94,11 @@ def op_node_is_compatible(
# If there are no valid texture memory layouts, then buffer storage must be
# supported by the operator implementation.
if len(valid_texture_layouts) == 0:
- # TODO: once memory metadata tagging pass is implemented, check that the
- # op impl supports buffers instead
- return False, "requires buffer representation"
+ compatible = VkStorageType.BUFFER in features.supported_storage_types()
+ reason = "op is compatible"
+ if not compatible:
+ reason = "op requires buffer storage, which is not supported by the op impl"
+ return compatible, reason

op_available_layouts = features.supported_memory_layouts(
VkStorageType.TEXTURE_3D
16 changes: 16 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_builder.py
@@ -12,6 +12,11 @@
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
VkMemoryLayout,
VkStorageType,
)
from executorch.backends.vulkan.utils import (
is_constant,
is_get_attr_node,
@@ -169,6 +174,15 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
if spec.mem_obj_id is not None:
mem_obj_id = spec.mem_obj_id

storage_type = VkStorageType.DEFAULT_STORAGE
memory_layout = VkMemoryLayout.DEFAULT_LAYOUT
if hasattr(spec, "vk_storage_type"):
# pyre-ignore[16]
storage_type = spec.vk_storage_type
if hasattr(spec, "vk_memory_layout"):
# pyre-ignore[16]
memory_layout = spec.vk_memory_layout

new_id = len(self.values)
self.values.append(
vk_graph_schema.VkValue(
Expand All @@ -177,6 +191,8 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
dims=spec.shape,
constant_id=constant_id,
mem_obj_id=mem_obj_id,
storage_type=storage_type,
memory_layout=memory_layout,
)
)
)
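
A more compact equivalent of the hasattr checks above (a sketch only; behavior should be identical) would be getattr with a default, which also sidesteps the pyre-ignore annotations:

storage_type = getattr(spec, "vk_storage_type", VkStorageType.DEFAULT_STORAGE)
memory_layout = getattr(spec, "vk_memory_layout", VkMemoryLayout.DEFAULT_LAYOUT)
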
6 changes: 6 additions & 0 deletions backends/vulkan/serialization/vulkan_graph_schema.py
@@ -37,13 +37,19 @@ class VkStorageType(IntEnum):
TEXTURE_2D = 2
DEFAULT_STORAGE = 255

def __str__(self) -> str:
return self.name


class VkMemoryLayout(IntEnum):
TENSOR_WIDTH_PACKED = 0
TENSOR_HEIGHT_PACKED = 1
TENSOR_CHANNELS_PACKED = 2
DEFAULT_LAYOUT = 255

def __str__(self) -> str:
return self.name


@dataclass
class VkTensor:
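
The __str__ overrides mean the transition logs emitted by tag_memory_meta_pass.py render bare enum names rather than the default IntEnum representation (which varies across Python versions). A quick illustration of the assumed output:

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

# With the override, str() and f-string formatting both pick up the name.
print(VkStorageType.BUFFER)                 # BUFFER
print(VkMemoryLayout.TENSOR_WIDTH_PACKED)   # TENSOR_WIDTH_PACKED
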
2 changes: 2 additions & 0 deletions backends/vulkan/targets.bzl
@@ -223,6 +223,8 @@ def define_common_targets(is_fbcode = False):
],
deps = [
"//caffe2:torch",
"//executorch/exir:tensor",
"//executorch/backends/vulkan/serialization:lib",
]
)
