[ET-VK][Op Redesign][6/n] Merge ArithmeticPrepack into PrepackNode #2261

Closed
wants to merge 1 commit into from
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -8,6 +8,8 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
+
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
 
 namespace at {
55 changes: 55 additions & 0 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
@@ -0,0 +1,55 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

namespace at {
namespace native {
namespace vulkan {

void PrepackNode::encode(ComputeGraph* graph) {
api::Context* const context = graph->context();
api::PipelineBarrier pipeline_barrier{};

TensorRef tref = graph->get_val(tref_).toTensorRef();
vTensor packed = graph->get_val(packed_).toTensor();

// TODO: Extract to standalone function, to support other types of prepacking.
api::StorageBuffer staging(
graph->context(), packed.dtype(), packed.gpu_nbytes());
size_t numel = api::utils::multiply_integers(tref.sizes);
size_t nbytes = numel * api::element_size(tref.dtype);
copy_ptr_to_staging(tref.data, staging, nbytes);

std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

api::DescriptorSet descriptor_set =
context->get_descriptor_set(shader_, local_workgroup_size_);

uint32_t idx = 0;
bind_tensor_to_descriptor_set(
packed,
pipeline_barrier,
api::MemoryAccessType::WRITE,
descriptor_set,
idx++);
bind_staging_to_descriptor_set(staging, descriptor_set, idx++);
descriptor_set.bind(idx, params_.buffer());

context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
}

} // namespace vulkan
} // namespace native
} // namespace at
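For context, here is a minimal sketch of how these nodes could be driven from the graph side. The ComputeGraph prepack pass is not part of this diff, so the function name, the unique_ptr storage, and the submission comment below are assumptions for illustration, not code from this PR.

```cpp
#include <memory>
#include <vector>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>

namespace at {
namespace native {
namespace vulkan {

// Hypothetical driver loop: each PrepackNode records one staging upload plus
// one packing-shader dispatch into the graph's command buffer.
void encode_all_prepack_nodes(
    ComputeGraph* graph,
    std::vector<std::unique_ptr<PrepackNode>>& prepack_nodes) {
  for (std::unique_ptr<PrepackNode>& node : prepack_nodes) {
    node->encode(graph);
  }
  // Command buffer submission and fencing are handled by the graph runtime
  // and are out of scope here.
}

} // namespace vulkan
} // namespace native
} // namespace at
```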
33 changes: 25 additions & 8 deletions backends/vulkan/runtime/graph/ops/PrepackNode.h
@@ -28,20 +28,37 @@ class ComputeGraph;
  * encoding of shaders transferring necessary data (such as weights and biases)
  * to the GPU.
  */
-class PrepackNode {
+class PrepackNode final {
   friend class ComputeGraph;
 
  public:
-  PrepackNode(ValueRef tref, ValueRef packed) : tref_{tref}, packed_{packed} {}
+  PrepackNode(
+      const api::ShaderInfo& shader,
+      const api::utils::uvec3& global_workgroup_size,
+      const api::utils::uvec3& local_workgroup_size,
+      const ValueRef tref,
+      const ValueRef packed,
+      api::UniformParamsBuffer&& params)
+      : shader_(shader),
+        global_workgroup_size_(global_workgroup_size),
+        local_workgroup_size_(local_workgroup_size),
+        tref_(tref),
+        packed_(packed),
+        params_(std::move(params)) {}
 
-  virtual ~PrepackNode() = default;
+  ~PrepackNode() = default;
 
- protected:
-  ValueRef tref_;
-  ValueRef packed_;
+  void encode(ComputeGraph* graph);
 
- public:
-  virtual void encode(ComputeGraph* graph) const = 0;
+ protected:
+  const api::ShaderInfo shader_;
+  const api::utils::uvec3 global_workgroup_size_;
+  const api::utils::uvec3 local_workgroup_size_;
+  const ValueRef tref_;
+  const ValueRef packed_;
+  // TODO(T180906086): pass multiple buffers and index with ValueRef.
+  // TODO(T180906457): allow re-computing param buffers.
+  api::UniformParamsBuffer params_;
 };
 
 } // namespace vulkan
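To illustrate what this redesign buys: instead of subclassing (as ArithmeticPrepack did), an op can now instantiate PrepackNode directly with a shader and a uniform parameter buffer. The helper below is a sketch only; the params struct, the fixed local workgroup size, and the function name are assumptions rather than code from this PR.

```cpp
#include <memory>

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>

namespace at {
namespace native {
namespace vulkan {

// Hypothetical uniform block consumed by the packing shader.
struct PackingParams final {
  api::utils::uvec3 extents;
  uint32_t padding;
};

// Sketch: build a prepack node that copies the CPU data referenced by `tref`
// into the GPU texture backing `packed`.
std::unique_ptr<PrepackNode> make_weight_prepack_node(
    ComputeGraph& graph,
    const ValueRef tref,
    const ValueRef packed) {
  vTensor t_packed = graph.get_val(packed).toTensor();

  api::utils::uvec3 global_size = t_packed.extents();
  api::utils::uvec3 local_size = {4u, 4u, 4u}; // fixed size for illustration

  PackingParams block{t_packed.extents(), 0u};
  api::UniformParamsBuffer params(graph.context(), block);

  return std::make_unique<PrepackNode>(
      get_nchw_to_image_shader(t_packed), // shader picked via StagingUtils
      global_size,
      local_size,
      tref,
      packed,
      std::move(params));
}

} // namespace vulkan
} // namespace native
} // namespace at
```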
174 changes: 174 additions & 0 deletions backends/vulkan/runtime/graph/ops/StagingUtils.cpp
@@ -0,0 +1,174 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

#include <ATen/native/vulkan/impl/Common.h>

namespace at {
namespace native {
namespace vulkan {

void memcpy_to_mapping(
const void* src,
api::MemoryMap& dst_mapping,
const size_t nbytes,
const api::ScalarType dtype) {
#define DTYPE_CASE(ctype, vkformat, name) \
case api::ScalarType::name: \
memcpy_to_mapping_impl<ctype>(src, dst_mapping, nbytes); \
break;

switch (dtype) {
VK_FORALL_SCALAR_TYPES(DTYPE_CASE)
default:
VK_THROW("Unrecognized dtype!");
}
#undef DTYPE_CASE
}

void memcpy_from_mapping(
api::MemoryMap& src_mapping,
void* dst,
const size_t nbytes,
const api::ScalarType dtype) {
#define DTYPE_CASE(ctype, vkformat, name) \
case api::ScalarType::name: \
memcpy_from_mapping_impl<ctype>(src_mapping, dst, nbytes); \
break;

switch (dtype) {
VK_FORALL_SCALAR_TYPES(DTYPE_CASE)
default:
VK_THROW("Unrecognized dtype!");
}
#undef DTYPE_CASE
}

void copy_ptr_to_staging(
const void* src,
api::StorageBuffer& staging,
const size_t nbytes) {
api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);
mapping.invalidate();
memcpy_to_mapping(src, mapping, nbytes, staging.dtype());
}

void copy_staging_to_ptr(
api::StorageBuffer& staging,
void* dst,
const size_t nbytes) {
api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::READ);
mapping.invalidate();
memcpy_from_mapping(mapping, dst, nbytes, staging.dtype());
}

api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
if (v_dst.is_quantized()) {
switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
switch (v_dst.dtype()) {
case api::ScalarType::QUInt8:
return VK_KERNEL(nchw_to_image_uint8);
case api::ScalarType::QInt8:
return VK_KERNEL(nchw_to_image_int8);
case api::ScalarType::QInt32:
return VK_KERNEL(nchw_to_image_int32);
default:
VK_THROW(
"Vulkan quantization currently not supported for dtype ",
v_dst.dtype());
}
case api::StorageType::TEXTURE_2D:
switch (v_dst.dtype()) {
case api::ScalarType::QUInt8:
return VK_KERNEL(nchw_to_image2d_uint8);
case api::ScalarType::QInt8:
return VK_KERNEL(nchw_to_image2d_int8);
case api::ScalarType::QInt32:
return VK_KERNEL(nchw_to_image2d_int32);
default:
VK_THROW(
"Vulkan quantization currently not supported for dtype ",
v_dst.dtype());
}
      case api::StorageType::BUFFER:
      case api::StorageType::UNKNOWN:
        VK_THROW("Requested storage type must be a texture type.");
      default:
        VK_THROW("No kernel available!");
}
}

if (v_dst.dtype() == api::kFloat) {
switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(nchw_to_image);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(nchw_to_image2d);
default:
VK_THROW("No kernel available!");
}
} else if (v_dst.dtype() == api::kBool) {
switch (v_dst.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(nchw_to_image_bool);
default:
VK_THROW("No kernel available!");
}
} else {
VK_THROW("Unsupported dtype!");
}
}

api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
if (v_src.is_quantized() || v_src.dtype() == api::kBool) {
auto plane_size =
dim_at<Dim4D::Height>(v_src) * dim_at<Dim4D::Width>(v_src);
switch (v_src.storage_type()) {
case api::StorageType::TEXTURE_3D:
switch (v_src.dtype()) {
case api::ScalarType::QUInt8:
case api::ScalarType::QInt8:
case api::kBool:
return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
: VK_KERNEL(image_to_nchw_uint);
case api::ScalarType::QInt32:
return VK_KERNEL(image_to_nchw_int32);
default:
VK_THROW(
"Vulkan quantization currently not supported for dtype ",
v_src.dtype());
}
      case api::StorageType::BUFFER:
      case api::StorageType::UNKNOWN:
        VK_THROW("Requested storage type must be a texture type.");
      default:
        VK_THROW("No kernel available!");
}
}

if (v_src.dtype() == api::kFloat) {
switch (v_src.storage_type()) {
case api::StorageType::TEXTURE_3D:
return VK_KERNEL(image_to_nchw);
case api::StorageType::TEXTURE_2D:
return VK_KERNEL(image2d_to_nchw);
default:
VK_THROW("No kernel available!");
}
} else {
VK_THROW("Unsupported dtype!");
}
}

} // namespace vulkan
} // namespace native
} // namespace at
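For clarity, this is roughly what one case of the DTYPE_CASE dispatch above expands to, hand-written for a single float path; the Float enumerator and the single-type wrapper function are assumptions for illustration.

```cpp
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>

namespace at {
namespace native {
namespace vulkan {

// Hand-expanded sketch of the switch that VK_FORALL_SCALAR_TYPES(DTYPE_CASE)
// generates inside memcpy_to_mapping, reduced to one dtype.
void memcpy_to_mapping_float_only(
    const void* src,
    api::MemoryMap& dst_mapping,
    const size_t nbytes,
    const api::ScalarType dtype) {
  switch (dtype) {
    case api::ScalarType::Float:
      memcpy_to_mapping_impl<float>(src, dst_mapping, nbytes);
      break;
    default:
      VK_THROW("Unrecognized dtype!");
  }
}

} // namespace vulkan
} // namespace native
} // namespace at
```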
82 changes: 82 additions & 0 deletions backends/vulkan/runtime/graph/ops/StagingUtils.h
@@ -0,0 +1,82 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#ifdef USE_VULKAN_API

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <cstring>

namespace at {
namespace native {
namespace vulkan {

//
// Functions to memcpy data into staging buffer
//

void memcpy_to_mapping(
const void* src,
api::MemoryMap& dst_mapping,
const size_t nbytes,
const api::ScalarType dtype);
void memcpy_from_mapping(
api::MemoryMap& src_mapping,
void* dst,
const size_t nbytes,
const api::ScalarType dtype);

//
// Utility functions for memcpy
//

template <typename T>
void memcpy_to_mapping_impl(
const void* src,
api::MemoryMap& dst_mapping,
const size_t nbytes) {
T* data_ptr = dst_mapping.template data<T>();
memcpy(data_ptr, reinterpret_cast<const T*>(src), nbytes);
}

template <typename T>
void memcpy_from_mapping_impl(
api::MemoryMap& src_mapping,
void* dst,
const size_t nbytes) {
T* data_ptr = src_mapping.template data<T>();
memcpy(reinterpret_cast<T*>(dst), data_ptr, nbytes);
}

//
// Functions to copy data into and out of a staging buffer
//

void copy_ptr_to_staging(
const void* src,
api::StorageBuffer& staging,
const size_t nbytes);
void copy_staging_to_ptr(
api::StorageBuffer& staging,
void* dst,
const size_t nbytes);

//
// Functions to get shaders
//

api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst);
api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src);

} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */
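Finally, a minimal usage sketch of the staging helpers declared above, assuming a float source buffer and an existing api::Context; in the real graph runtime the read-back path is used to copy tensor data back to the CPU rather than to round-trip the same buffer.

```cpp
#include <vector>

#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>

namespace at {
namespace native {
namespace vulkan {

// Sketch only: upload host data to a staging buffer, then read it back.
void staging_round_trip(api::Context* const context, const size_t numel) {
  std::vector<float> src(numel, 1.0f);
  std::vector<float> dst(numel, 0.0f);

  api::StorageBuffer staging(context, api::kFloat, numel);
  const size_t nbytes = numel * api::element_size(api::kFloat);

  copy_ptr_to_staging(src.data(), staging, nbytes);
  // A dispatch such as PrepackNode::encode would normally consume the staging
  // buffer here before anything is read back.
  copy_staging_to_ptr(staging, dst.data(), nbytes);
}

} // namespace vulkan
} // namespace native
} // namespace at
```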