Commit

Merge branch 'main' into upload-test-spec-s3

huydhn committed Aug 13, 2024
2 parents 5dfedea + 2654f59 commit acc6482
Showing 29 changed files with 578 additions and 166 deletions.
4 changes: 2 additions & 2 deletions .ci/scripts/test_llama.sh
@@ -130,9 +130,9 @@ cleanup_files() {
 prepare_artifacts_upload() {
   if [ -n "$UPLOAD_DIR" ]; then
     echo "Preparing for uploading generated artifacts"
-    zip -j model.zip "${EXPORTED_MODEL_NAME}" tokenizer.bin
     mkdir -p "${UPLOAD_DIR}"
+    zip -j "model.zip" "${MODEL_NAME}" tokenizer.bin
-    cp "model.zip" "${UPLOAD_DIR}"
+    mv model.zip "${UPLOAD_DIR}"
   fi
 }

16 changes: 16 additions & 0 deletions .github/workflows/android-perf.yml
@@ -82,11 +82,27 @@ jobs:
       - name: Set parameters
         id: set-parameters
         shell: bash
+        env:
+          # Separate default values from the workflow dispatch inputs to ensure
+          # that defaults are accessible during scheduled runs and to allow
+          # different defaults for on-demand and periodic benchmarking.
+          CRON_DEFAULT_MODELS: "stories110M"
+          CRON_DEFAULT_DEVICES: "samsung_galaxy_s2x"
+          CRON_DEFAULT_DELEGATES: "xnnpack"
         run: |
           set -ex
           MODELS="${{ inputs.models }}"
+          if [ -z "$MODELS" ]; then
+            MODELS="$CRON_DEFAULT_MODELS"
+          fi
           DEVICES="${{ inputs.devices }}"
+          if [ -z "$DEVICES" ]; then
+            DEVICES="$CRON_DEFAULT_DEVICES"
+          fi
           DELEGATES="${{ inputs.delegates }}"
+          if [ -z "$DELEGATES" ]; then
+            DELEGATES="$CRON_DEFAULT_DELEGATES"
+          fi
           # Mapping devices to their corresponding device-pool-arn
           declare -A DEVICE_POOL_ARNS
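The env-based defaults above exist because `workflow_dispatch` inputs arrive empty on scheduled (cron) runs. A minimal Python sketch of the same fallback logic the `run` step performs (the parameter values here are illustrative, not part of the workflow):

```python
def resolve_param(dispatch_value: str, cron_default: str) -> str:
    # On a scheduled run the workflow_dispatch inputs are empty strings,
    # so fall back to the cron default in that case.
    return dispatch_value if dispatch_value else cron_default

# Scheduled (cron) run: inputs are empty, the default wins.
assert resolve_param("", "stories110M") == "stories110M"
# On-demand run: the dispatched value wins.
assert resolve_param("llama2", "stories110M") == "llama2"
```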
4 changes: 4 additions & 0 deletions backends/qualcomm/partition/common_defs.py
@@ -16,6 +16,10 @@
     exir_ops.edge.aten.copy.default,
 ]

+to_be_implemented_operator = [
+    exir_ops.edge.aten.where.default,
+]
+
 allow_list_operator = [
     _operator.getitem,
 ]
12 changes: 11 additions & 1 deletion backends/qualcomm/partition/qnn_partitioner.py
@@ -27,7 +27,11 @@
 from torch.fx.passes.infra.partitioner import Partition
 from torch.fx.passes.operator_support import OperatorSupportBase

-from .common_defs import allow_list_operator, not_supported_operator
+from .common_defs import (
+    allow_list_operator,
+    not_supported_operator,
+    to_be_implemented_operator,
+)


 class QnnOperatorSupport(OperatorSupportBase):
@@ -62,6 +66,12 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
if node.op != "call_function" or node.target in not_supported_operator:
return False

if node.target in to_be_implemented_operator:
print(
f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped, this op can be supported, please report an issue in https://github.com/pytorch/executorch/issues"
)
return False

if node.target in allow_list_operator:
return True

8 changes: 8 additions & 0 deletions backends/vulkan/runtime/api/containers/Tensor.h
@@ -277,6 +277,14 @@ class vTensor final {
     return sizes_.size();
   }

+  inline const std::vector<int64_t>& strides() const {
+    return strides_;
+  }
+
+  inline const std::vector<int64_t>& unsqueezed_strides() const {
+    return unsqueezed_strides_;
+  }
+
   /*
    * Returns a GPU buffer containing the sizes of the tensor in WHCN order.
    * Note that dimensions that are not present in the tensor's sizes are set to
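As a rough illustration of what a `strides()` accessor returns, here is a sketch of the standard contiguous-strides computation (a generic convention shown for illustration; the Vulkan backend's actual stride layout, including what `unsqueezed_strides()` pads, may differ):

```python
def contiguous_strides(sizes: list[int]) -> list[int]:
    # Each dimension's stride is the product of all inner dimension sizes.
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * sizes[i + 1]
    return strides

# A contiguous tensor with sizes (2, 3, 4, 5):
assert contiguous_strides([2, 3, 4, 5]) == [60, 20, 5, 1]
```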
@@ -1,4 +1,3 @@
-
 #version 450 core

 #define PRECISION ${PRECISION}
35 changes: 35 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl
@@ -0,0 +1,35 @@
#version 450 core

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}

#include "indexing_utils.h"

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(0, "w", "nchw_buf", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_ubo(2, "ivec4", "in_sizes")}
${layout_declare_ubo(3, "ivec4", "in_strides")}
${layout_declare_ubo(4, "int", "numel")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// This constant is unused in this shader but is kept so that the signature is
// consistent with image_to_nchw.
layout(constant_id = 3) const int UNUSED_packed_dim = W_DIM;

void main() {
  int out_id = int(gl_GlobalInvocationID.x);
  if (out_id >= numel) {
    return;
  }

  ivec4 t_in_idx = from_nchw_buffer_i(out_id, in_sizes);
  const int in_id = to_buffer_id(t_in_idx, in_strides);

  nchw_buf[out_id] = t_in[in_id];
}
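To see what this shader computes, here is a CPU reference in Python: each loop iteration corresponds to one shader invocation, and the index helpers mirror `from_nchw_buffer_i` and `to_buffer_id` from `indexing_utils.h` (a sketch for illustration, not production code):

```python
def from_nchw_buffer(buf_i: int, sizes):
    # sizes is (W, H, C, N); a flat NCHW index decomposes with W fastest.
    w, h, c, _ = sizes
    return (buf_i % w, (buf_i // w) % h, (buf_i // (w * h)) % c, buf_i // (w * h * c))

def to_buffer_id(idx, strides) -> int:
    # Dot product of the (w, h, c, n) index with the tensor's strides.
    return sum(i * s for i, s in zip(idx, strides))

def buffer_to_nchw(t_in, in_sizes, in_strides):
    numel = in_sizes[0] * in_sizes[1] * in_sizes[2] * in_sizes[3]
    nchw_buf = [0] * numel
    for out_id in range(numel):  # one shader invocation per output element
        t_in_idx = from_nchw_buffer(out_id, in_sizes)
        nchw_buf[out_id] = t_in[to_buffer_id(t_in_idx, in_strides)]
    return nchw_buf

# For a width-packed contiguous tensor, strides are (1, W, W*H, W*H*C)
# and the staging copy is the identity:
assert buffer_to_nchw(list(range(6)), (2, 3, 1, 1), (1, 2, 6, 6)) == list(range(6))
```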
18 changes: 18 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

buffer_to_nchw:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
      - VALUE: int
      - VALUE: int8
  shader_variants:
    - NAME: buffer_to_nchw
55 changes: 46 additions & 9 deletions backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
@@ -41,6 +41,21 @@
  */
 #define alignup4(x) ((x + 3) & -4)

+/*
+ * Input: (W, H, C, N) strides of a tensor
+ * Returns: the WHCN index of the fastest moving dimension
+ */
+int find_packed_dim(const ivec4 strides) {
+  int packed_dim = 0;
+  for (int i = 0; i <= 3; i++) {
+    if (strides[i] == 1) {
+      packed_dim = i;
+      break;
+    }
+  }
+  return packed_dim;
+}
+
 //
 // (w, h, c, n) Tensor Index <-> Contiguous Buffer Index Conversion
 //
@@ -74,27 +89,49 @@ ivec4 from_nchw_buffer_i(int buf_i, ivec4 sizes) {
       (buf_i / (sizes.x * sizes.y * sizes.z)));
 }

+int to_nchw_buffer_i(const ivec4 tensor_idx, const ivec4 sizes) {
+  return tensor_idx.w * sizes.x * sizes.y * sizes.z +
+      tensor_idx.z * sizes.x * sizes.y + tensor_idx.y * sizes.x + tensor_idx.x;
+}
+
 /*
  * Input: Texel buffer index, (W, H, C, N) strides of a tensor, which dim is
  * packed along a texel
- * Returns: The (x, y, z, n) texel position corresponding to the first element
- * of the texel at the specified buffer index
+ * Returns: The (w, h, c, n) tensor index corresponding to the buffer element
  */
-ivec4 to_tensor_idx(int buf_i, ivec4 strides, int packed_dim) {
+ivec4 to_tensor_idx(int buffer_id, const ivec4 strides, const int packed_dim) {
   ivec4 idx;
   for (int i = 3; i >= 0; i--) {
     if (i != packed_dim) {
-      idx[i] = buf_i / strides[i];
-      buf_i %= strides[i];
+      idx[i] = buffer_id / strides[i];
+      buffer_id %= strides[i];
     }
   }
-  idx[packed_dim] = buf_i;
+  idx[packed_dim] = buffer_id;
   return idx;
 }

-int to_texel_idx(const ivec4 texel_pos, ivec4 strides) {
-  return texel_pos.x * strides.x + texel_pos.y * strides.y +
-      texel_pos.z * strides.z + texel_pos.w * strides.w;
+/*
+ * Input: Texel buffer index, (W, H, C, N) strides of a tensor
+ * Returns: The (w, h, c, n) tensor index corresponding to the buffer element
+ *
+ * This is a convenience overload of the above function. If the packed dim is
+ * not known, it can be found by locating the first dimension with a stride of
+ * 1. However, that search adds some overhead, so if performance is a concern
+ * the overload above should be used instead, with the packed dim provided.
+ */
+ivec4 to_tensor_idx(int buffer_id, const ivec4 strides) {
+  int packed_dim = find_packed_dim(strides);
+  return to_tensor_idx(buffer_id, strides, packed_dim);
+}
+
+/*
+ * Input: (w, h, c, n) tensor index, (W, H, C, N) strides of the tensor buffer
+ * Returns: the buffer index corresponding to the specified tensor index
+ */
+int to_buffer_id(const ivec4 tensor_idx, ivec4 strides) {
+  return tensor_idx.x * strides.x + tensor_idx.y * strides.y +
+      tensor_idx.z * strides.z + tensor_idx.w * strides.w;
 }

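A quick way to sanity-check these helpers is a round trip: converting a buffer index to a (w, h, c, n) tensor index and back should be the identity for any contiguous stride layout. A Python sketch mirroring the GLSL above (illustrative only):

```python
def find_packed_dim(strides) -> int:
    # The packed (fastest-moving) dim is the first one with stride 1.
    for i in range(4):
        if strides[i] == 1:
            return i
    return 0

def to_tensor_idx(buffer_id: int, strides, packed_dim: int):
    # Peel off coordinates from the slowest-moving dim down; whatever
    # remains is the coordinate along the packed dim (stride 1).
    idx = [0, 0, 0, 0]
    for i in range(3, -1, -1):
        if i != packed_dim:
            idx[i] = buffer_id // strides[i]
            buffer_id %= strides[i]
    idx[packed_dim] = buffer_id
    return idx

def to_buffer_id(idx, strides) -> int:
    return sum(i * s for i, s in zip(idx, strides))

# Width-packed (W, H, C, N) = (4, 3, 2, 1) tensor -> strides (1, 4, 12, 24).
strides = [1, 4, 12, 24]
for buffer_id in range(24):
    idx = to_tensor_idx(buffer_id, strides, find_packed_dim(strides))
    assert to_buffer_id(idx, strides) == buffer_id
```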
35 changes: 35 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl
@@ -0,0 +1,35 @@
#version 450 core

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}

#include "indexing_utils.h"

${define_required_extensions(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "nchw_in", DTYPE, STORAGE)}
${layout_declare_ubo(2, "ivec4", "out_sizes")}
${layout_declare_ubo(3, "ivec4", "out_strides")}
${layout_declare_ubo(4, "int", "numel")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// This constant is unused in this shader but is kept so that the signature is
// consistent with nchw_to_image.
layout(constant_id = 3) const int UNUSED_packed_dim = W_DIM;

void main() {
  int out_id = int(gl_GlobalInvocationID.x);
  if (out_id >= numel) {
    return;
  }

  ivec4 out_idx = to_tensor_idx(out_id, out_strides);
  const int in_id = to_nchw_buffer_i(out_idx, out_sizes);

  t_out[out_id] = nchw_in[in_id];
}
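This shader is the inverse of `buffer_to_nchw` above: it gathers from the contiguous NCHW staging buffer into the strided tensor buffer. A Python reference of the same loop (illustrative; it uses the convenience `to_tensor_idx` overload that locates the packed dim itself):

```python
def to_tensor_idx(buffer_id: int, strides):
    # Convenience overload: locate the packed dim (stride == 1) first.
    packed_dim = next((i for i in range(4) if strides[i] == 1), 0)
    idx = [0, 0, 0, 0]
    for i in range(3, -1, -1):
        if i != packed_dim:
            idx[i] = buffer_id // strides[i]
            buffer_id %= strides[i]
    idx[packed_dim] = buffer_id
    return idx

def to_nchw_buffer_i(idx, sizes) -> int:
    # Flat NCHW index from a (w, h, c, n) tensor index.
    w, h, c, _ = sizes
    return idx[3] * w * h * c + idx[2] * w * h + idx[1] * w + idx[0]

def nchw_to_buffer(nchw_in, out_sizes, out_strides):
    numel = out_sizes[0] * out_sizes[1] * out_sizes[2] * out_sizes[3]
    t_out = [0] * numel
    for out_id in range(numel):  # one shader invocation per output element
        out_idx = to_tensor_idx(out_id, out_strides)
        t_out[out_id] = nchw_in[to_nchw_buffer_i(out_idx, out_sizes)]
    return t_out

# Again, for a width-packed contiguous tensor the copy is the identity:
assert nchw_to_buffer(list(range(24)), (4, 3, 2, 1), (1, 4, 12, 24)) == list(range(24))
```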
18 changes: 18 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

nchw_to_buffer:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: buffer
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
      - VALUE: int
      - VALUE: int8
  shader_variants:
    - NAME: nchw_to_buffer
12 changes: 9 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -26,7 +26,10 @@ void add_staging_to_tensor_node(

   vkapi::ParamsBindList ubos;
   if (graph.is_buffer_storage(out_tensor)) {
-    ubos.append(graph.numel_ubo(out_tensor));
+    ubos.append(
+        {graph.sizes_ubo(out_tensor),
+         graph.strides_ubo(out_tensor),
+         graph.numel_ubo(out_tensor)});
   } else {
     ubos.append(graph.sizes_ubo(out_tensor));
   }
@@ -61,7 +64,10 @@ void add_tensor_to_staging_node(

   vkapi::ParamsBindList ubos;
   if (graph.is_buffer_storage(in_tensor)) {
-    ubos.append(graph.numel_ubo(in_tensor));
+    ubos.append(
+        {graph.sizes_ubo(in_tensor),
+         graph.strides_ubo(in_tensor),
+         graph.numel_ubo(in_tensor)});
   } else {
     ubos.append(graph.sizes_ubo(in_tensor));
   }
@@ -105,7 +111,7 @@ ValueRef prepack(

   vkapi::ParamsBindList ubos;
   if (graph.is_buffer_storage(v)) {
-    ubos.append(graph.numel_ubo(v));
+    ubos.append({graph.sizes_ubo(v), graph.strides_ubo(v), graph.numel_ubo(v)});
   } else {
     ubos.append(graph.sizes_ubo(v));
   }
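Note that the order of the appended parameters (sizes, strides, numel) lines up with the `layout_declare_ubo` slots 2-4 declared by the new `nchw_to_buffer` and `buffer_to_nchw` shaders above, which is presumably why all three are now bound for buffer-backed tensors instead of numel alone.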
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp
@@ -107,7 +107,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader(
   }

   if (v_dst.storage_type() == utils::kBuffer) {
-    kernel_name = "buffer_to_buffer";
+    kernel_name = "nchw_to_buffer";
     add_dtype_suffix(kernel_name, v_dst);
     return VK_KERNEL_FROM_STR(kernel_name);
   }
@@ -131,7 +131,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
   }

   if (v_src.storage_type() == utils::kBuffer) {
-    kernel_name = "buffer_to_buffer";
+    kernel_name = "buffer_to_nchw";
     add_dtype_suffix(kernel_name, v_src);
     return VK_KERNEL_FROM_STR(kernel_name);
   }
13 changes: 6 additions & 7 deletions backends/vulkan/test/utils/test_utils.cpp
@@ -23,22 +23,22 @@ void record_nchw_to_buffer_op(
     vkapi::VulkanBuffer& src_buffer,
     api::vTensor& v_dst) {
   vkapi::PipelineBarrier pipeline_barrier{};
-  vkapi::SpecVarList specialization_constants = {
-      SV(v_dst.packed_dim_whcn_idx())};

   context->submit_compute_job(
       get_nchw_to_tensor_shader(v_dst),
       pipeline_barrier,
       {uint32_t(v_dst.numel()), 1, 1},
       {64, 1, 1},
-      specialization_constants,
+      {},
       VK_NULL_HANDLE,
       0,
       v_dst.buffer(
           pipeline_barrier,
           vkapi::PipelineStage::COMPUTE,
           vkapi::MemoryAccessType::WRITE),
       src_buffer,
+      v_dst.sizes_ubo(),
+      v_dst.strides_ubo(),
       v_dst.numel_ubo());
 }

@@ -47,19 +47,18 @@ void record_buffer_to_nchw_op(
     api::vTensor& v_src,
     vkapi::VulkanBuffer& dst_buffer) {
   vkapi::PipelineBarrier pipeline_barrier{};
-  vkapi::SpecVarList specialization_constants = {
-      SV(v_src.packed_dim_whcn_idx())};
-
   context->submit_compute_job(
       get_tensor_to_nchw_shader(v_src),
       pipeline_barrier,
       {uint32_t(v_src.numel()), 1, 1},
       {64, 1, 1},
-      specialization_constants,
+      {},
       VK_NULL_HANDLE,
       0,
       dst_buffer,
       v_src.buffer(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
+      v_src.sizes_ubo(),
+      v_src.strides_ubo(),
       v_src.numel_ubo());
 }

