[ET-VK] Miscellaneous fixes for Vulkan docs (#5986)
Miscellaneous fixes for Vulkan docs (#5972)

Summary:
Pull Request resolved: #5972

## Context

Implement various fixes to the Vulkan delegate, found while performing QA for the Vulkan docs. I elected to package everything into this one diff so that it is easy to cherry-pick into the 0.4 release.

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: jorgep31415

Differential Revision: D64022818

Pulled By: SS-JIA

fbshipit-source-id: 35782970e9db1ab33154ccbae2e10c77d911c041
(cherry picked from commit 513d166)

Co-authored-by: Stephen Jia <ssjia@meta.com>
pytorchbot and SS-JIA authored Oct 8, 2024
1 parent 60cc6bc commit 0fa033c
Showing 4 changed files with 94 additions and 91 deletions.
65 changes: 39 additions & 26 deletions backends/vulkan/README.md
## End to End Example

To further understand the features of the Vulkan Delegate and how to use it,
consider the following end to end example with a simple single-operator model.

### Compile and lower a model to the Vulkan Delegate

Once ExecuTorch has been set up and installed, the following script can be used
to generate a simple model and lower it to the Vulkan delegate.

```python
# Note: this script is the same as the script from the "Setting up ExecuTorch"
# page, with one minor addition to lower to the Vulkan backend.
import torch
from torch.export import export, ExportedProgram

from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager, to_edge
from executorch.exir.backend.backend_api import to_backend

# Start with a PyTorch model that adds two input tensors (matrices)
class Add(torch.nn.Module):
    def __init__(self):
        super(Add, self).__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return x + y

# 1. torch.export: Defines the program with the ATen operator set.
aten_dialect: ExportedProgram = export(Add(), (torch.ones(1), torch.ones(1)))

# 2. to_edge: Make optimizations for Edge devices
edge_program: EdgeProgramManager = to_edge(aten_dialect)

# 2.1 Lower to the Vulkan backend
edge_program = edge_program.to_backend(VulkanPartitioner())

# 3. to_executorch: Convert the graph to an ExecuTorch program
executorch_program: ExecutorchProgramManager = edge_program.to_executorch()

# 4. Save the compiled .pte program
with open("vk_add.pte", "wb") as file:
    file.write(executorch_program.buffer)
```
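Assuming the script above is saved to a file (the filename `export_vk_add.py` below is just for illustration), it can be run and the resulting program checked as follows:

```shell
# Run the export script and confirm that the lowered program was written.
# export_vk_add.py is a hypothetical name for the script shown above.
python export_vk_add.py
ls -lh vk_add.pte
```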

Like other ExecuTorch delegates, a model can be lowered to the Vulkan Delegate
partially; only the portions of the graph supported by the delegate
will be executed on the GPU.


::::{note}
The [supported ops list](https://github.com/pytorch/executorch/blob/main/backends/vulkan/partitioner/supported_ops.py)
and the [Vulkan partitioner code](https://github.com/pytorch/executorch/blob/main/backends/vulkan/partitioner/vulkan_partitioner.py)
can be inspected to examine which ops are currently implemented in the Vulkan
delegate.
::::
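For a quick look from a local checkout, something like the following (a sketch, assuming the file still registers ops via `exir_ops.edge` identifiers) can be run from the ExecuTorch repo root:

```shell
# List the operator registrations in the Vulkan partitioner's supported ops file.
grep -n "exir_ops.edge" backends/vulkan/partitioner/supported_ops.py
```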

### Build Vulkan Delegate libraries

The easiest way to build and test the Vulkan Delegate is to build for Android
and test on a local Android device. Android devices have built-in support for
Vulkan, and the Android NDK ships with a GLSL compiler, which is needed to
compile the Vulkan Compute Library's GLSL compute shaders.

The Vulkan Delegate libraries can be built by setting `-DEXECUTORCH_BUILD_VULKAN=ON`
when building with CMake.

First, make sure that you have the Android NDK installed; any NDK version newer
than r19c should work. Note that the examples in this doc have been validated with
NDK r25. The Android SDK should also be installed so that you have access to `adb`.
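To confirm which NDK version is installed (a sketch; recent NDK releases include a `source.properties` file with the package revision, but verify this for your install):

```shell
# Print the NDK package revision, e.g. Pkg.Revision = 25.x.xxxxxxx for an r25 release.
grep "Pkg.Revision" $ANDROID_NDK/source.properties
```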

The instructions on this page assume that the following environment variables
are set.

```shell
# Recommended version is Android NDK r25c.
export ANDROID_NDK=<path_to_ndk>
# Select the appropriate Android ABI for your device
export ANDROID_ABI=arm64-v8a
# All subsequent commands should be performed from ExecuTorch repo root
cd <path_to_executorch_root>
```

[...]
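A minimal sketch of the configure step for building the delegate libraries, assuming the environment variables above and the standard Android CMake toolchain file (the exact flag set used in the full instructions may differ):

```shell
# Configure an Android build with the Vulkan delegate enabled (sketch only).
cmake . \
  -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
  -DANDROID_ABI=$ANDROID_ABI \
  -DEXECUTORCH_BUILD_VULKAN=ON \
  -DPYTHON_EXECUTABLE=python \
  -Bcmake-android-out
```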
```shell
cmake --build cmake-android-out --target vulkan_executor_runner -j32

# Push model to device
adb push vk_add.pte /data/local/tmp/vk_add.pte
# Push binary to device
adb push cmake-android-out/backends/vulkan/vulkan_executor_runner /data/local/tmp/runner_bin

# Run the model
adb shell /data/local/tmp/runner_bin --model_path /data/local/tmp/vk_add.pte
```
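If the runner fails to initialize the Vulkan backend, a quick first check (a sketch, not part of the original instructions) is whether the connected device advertises Vulkan support at all:

```shell
# Devices with Vulkan support report features such as android.hardware.vulkan.level.
adb shell pm list features | grep -i vulkan
```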
103 changes: 40 additions & 63 deletions backends/vulkan/docs/android_demo.md
The ExecuTorch Vulkan delegate is a native GPU delegate for ExecuTorch.
::::{grid} 2
:::{grid-item-card} What you will learn in this tutorial:
:class-card: card-content
* How to export the Llama3.2-1B model with partial GPU delegation
* How to execute the partially delegated model on Android
:::
:::{grid-item-card} Prerequisites:
:class-card: card-prerequisites
* Follow [**Setting up ExecuTorch**](./getting-started-setup.md)
* It is also recommended that you read through [**ExecuTorch Vulkan Delegate**](./native-delegates-executorch-vulkan-delegate.md) and follow the example on that page
:::
::::

Note that all the steps below should be performed from the ExecuTorch repository
root directory, and assumes that you have gone through the steps of setting up
ExecuTorch.

It is also assumed that the Android NDK and Android SDK are installed, and that
the following environment variables are set.

```shell
# Recommended version is Android NDK r25c.
export ANDROID_NDK=<path_to_ndk>
# Select an appropriate Android ABI for your device
export ANDROID_ABI=arm64-v8a
# All subsequent commands should be performed from ExecuTorch repo root
cd <path_to_executorch_root>
# Make sure adb works
adb --version
```
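Beyond `adb --version`, it is also worth confirming that the target device is connected and authorized before proceeding:

```shell
# The device should be listed with the state "device" (not "unauthorized" or "offline").
adb devices
```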

## Lowering the Llama3.2-1B model to Vulkan

::::{note}
The resultant model will only be partially delegated to the Vulkan backend. In
particular, only binary arithmetic operators (`aten.add`, `aten.sub`,
`aten.mul`, `aten.div`), matrix multiplication operators (`aten.mm`, `aten.bmm`),
and linear layers (`aten.linear`) will be executed on the GPU via the Vulkan
delegate. The rest of the model will be executed using Portable operators.

Operator support for LLaMA models is currently in active development; please
check out the `main` branch of the ExecuTorch repo for the latest capabilities.
::::

First, obtain the `consolidated.00.pth`, `params.json` and `tokenizer.model`
files for the `Llama3.2-1B` model from the [Llama website](https://www.llama.com/llama-downloads/).
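Once downloaded, the files can be verified locally (a sketch, assuming the default `~/.llama` download location used by the commands below):

```shell
# consolidated.00.pth and params.json should be present in the checkpoint directory.
ls ~/.llama/checkpoints/Llama3.2-1B
```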

Once the files have been downloaded, the `export_llama` script can be used to
partially lower the Llama model to Vulkan.

```shell
# The files will usually be downloaded to ~/.llama
python -m examples.models.llama2.export_llama \
  --disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
  -c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
  -p ~/.llama/checkpoints/Llama3.2-1B/params.json \
  --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
```

A `vulkan_llama2.pte` file should have been created as a result of running the
script.

Push the tokenizer model and `vulkan_llama2.pte` onto your Android device:

```shell
adb push ~/.llama/tokenizer.model /data/local/tmp/
adb push vulkan_llama2.pte /data/local/tmp/
```

## Build and Run the LLaMA runner binary on Android
Build the ExecuTorch core libraries and the LLaMA runner binary using the Android
NDK toolchain.

```shell
# [...]
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_BUILD_VULKAN=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DPYTHON_EXECUTABLE=python \
-Bcmake-android-out && \
cmake --build cmake-android-out -j16 --target install)
```

[...]

```shell
cmake examples/models/llama2 \
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
-DANDROID_ABI=$ANDROID_ABI \
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
-DCMAKE_INSTALL_PREFIX=cmake-android-out \
-DPYTHON_EXECUTABLE=python \
-Bcmake-android-out/examples/models/llama2 && \
cmake --build cmake-android-out/examples/models/llama2 -j16)
```

Finally, push and run the llama runner binary on your Android device. Note that
your device must have sufficient GPU memory to execute the model.

```shell
adb push cmake-android-out/examples/models/llama2/llama_main /data/local/tmp/llama_main

adb shell /data/local/tmp/llama_main \
    --model_path=/data/local/tmp/vulkan_llama2.pte \
    --tokenizer_path=/data/local/tmp/tokenizer.model \
    --prompt "Hello"
```

Note that model inference will currently be very slow due to the large number of
delegate blobs in the lowered graph, each of which requires a transfer to and from
the GPU. Performance is expected to improve drastically as more of the model is
lowered to the Vulkan delegate, and as techniques such as quantization are
supported.
15 changes: 13 additions & 2 deletions backends/vulkan/partitioner/vulkan_partitioner.py
```python
class VulkanSupportedOperators(OperatorSupportBase):
    def __init__(self, require_dynamic_shape: bool = False) -> None:
        super().__init__()
        self.require_dynamic_shapes = require_dynamic_shape
        # The tensor dim limit is to guard against tensors with one or more
        # large dimensions, which cannot be represented by an image texture due
        # to the texture axis limits.
        self.tensor_dim_limit = 16384

    # pyre-ignore
    def node_val_is_compatible(self, node_val: Any) -> bool:
        # [...]
        if node_val.dtype == torch.bool:
            return False

        for dim in node_val.shape:
            if dim > self.tensor_dim_limit:
                return False

        if isinstance(node_val, (list, tuple)):
            for item in node_val:
                if not self.node_val_is_compatible(item):
                    # [...]

    def is_linear_permute(self, node: torch.fx.Node) -> bool:
        # [...]
        if len(node.users) != 1:
            return False

        first_user = list(node.users.keys())[0]
        if first_user.target in [
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.addmm.default,
        ]:
            # Only mark this node if the overall linear op is valid
            if self.all_args_compatible(first_user):
                return True

        return False
```

2 changes: 2 additions & 0 deletions backends/vulkan/runtime/vk_api/Adapter.h
```cpp
// [...]

#include <executorch/backends/vulkan/runtime/vk_api/memory/Allocator.h>

#include <array>

namespace vkcompute {
namespace vkapi {

// [...]
```
