Skip to content

[ET-VK] Support multiple UniformParamsBuffer #2348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ class ComputeGraph final {
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);

template <typename Block>
inline std::shared_ptr<api::UniformParamsBuffer> create_params_buffer(
const Block& data) {
return std::make_shared<api::UniformParamsBuffer>(context_.get(), data);
}

/*
* Convenience function to add an input tensor along with its staging buffer
*/
Expand Down
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ ExecuteNode::ExecuteNode(
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
api::UniformParamsBuffer&& params)
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
args_(args),
params_(std::move(params)) {
params_(params) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

Expand All @@ -43,7 +43,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
uint32_t idx = 0;
idx = bind_values_to_descriptor_set(
graph, args_, pipeline_barrier, descriptor_set, idx);
descriptor_set.bind(idx, params_.buffer());
bind_params_to_descriptor_set(params_, descriptor_set, idx);

context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
Expand Down
5 changes: 2 additions & 3 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ExecuteNode final {
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
api::UniformParamsBuffer&& params);
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);

~ExecuteNode() = default;

Expand All @@ -64,9 +64,8 @@ class ExecuteNode final {
const api::utils::uvec3 global_workgroup_size_;
const api::utils::uvec3 local_workgroup_size_;
const std::vector<ArgGroup> args_;
// TODO(T180906086): pass multiple buffers and index with ValueRef.
// TODO(T180906457): allow re-computing param buffers.
api::UniformParamsBuffer params_;
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
};

} // namespace vulkan
Expand Down
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ PrepackNode::PrepackNode(
const api::utils::uvec3& local_workgroup_size,
const ValueRef tref,
const ValueRef packed,
api::UniformParamsBuffer&& params)
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
tref_(tref),
packed_(packed),
params_(std::move(params)) {
params_(params) {
graph.update_descriptor_counts(shader, /*execute = */ false);
}

Expand Down Expand Up @@ -61,7 +61,7 @@ void PrepackNode::encode(ComputeGraph* graph) {
descriptor_set,
idx++);
bind_staging_to_descriptor_set(staging, descriptor_set, idx++);
descriptor_set.bind(idx, params_.buffer());
bind_params_to_descriptor_set(params_, descriptor_set, idx);

context->register_shader_dispatch(
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
Expand Down
5 changes: 2 additions & 3 deletions backends/vulkan/runtime/graph/ops/PrepackNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class PrepackNode final {
const api::utils::uvec3& local_workgroup_size,
const ValueRef tref,
const ValueRef packed,
api::UniformParamsBuffer&& params);
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);

~PrepackNode() = default;

Expand All @@ -49,9 +49,8 @@ class PrepackNode final {
const api::utils::uvec3 local_workgroup_size_;
const ValueRef tref_;
const ValueRef packed_;
// TODO(T180906086): pass multiple buffers and index with ValueRef.
// TODO(T180906457): allow re-computing param buffers.
api::UniformParamsBuffer params_;
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
};

} // namespace vulkan
Expand Down
3 changes: 1 addition & 2 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ void add_arithmetic_node(
get_size_as_ivec4(t_in2),
alpha_val,
};
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
Expand All @@ -81,7 +80,7 @@ void add_arithmetic_node(
local_size,
{{out, api::MemoryAccessType::WRITE},
{{arg1, arg2}, api::MemoryAccessType::READ}},
std::move(params)));
{graph.create_params_buffer(block)}));
}

REGISTER_OPERATORS {
Expand Down
17 changes: 9 additions & 8 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,14 @@ void add_staging_to_tensor_node(
api::utils::uvec3 global_size = t_out.extents();
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

api::UniformParamsBuffer params(
graph.context(), create_staging_params(t_out));

graph.execute_nodes().emplace_back(new ExecuteNode(
graph,
shader,
global_size,
local_size,
{{out_tensor, api::MemoryAccessType::WRITE},
{in_staging, api::MemoryAccessType::READ}},
std::move(params)));
{graph.create_params_buffer(create_staging_params(t_out))}));
}

void add_tensor_to_staging_node(
Expand All @@ -71,7 +68,6 @@ void add_tensor_to_staging_node(
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

StagingParams sp = create_staging_params(t_in);
api::UniformParamsBuffer params(graph.context(), sp);

// TODO(T181194784): These are workgroup sizes for special cases. Refactor the
// calculation of workgroup sizes to a standalone function. We should use
Expand All @@ -98,7 +94,7 @@ void add_tensor_to_staging_node(
local_size,
{{in_tensor, api::MemoryAccessType::READ},
{out_staging, api::MemoryAccessType::WRITE}},
std::move(params)));
{graph.create_params_buffer(sp)}));
}

ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
Expand All @@ -112,10 +108,15 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

StagingParams sp = create_staging_params(t);
api::UniformParamsBuffer params(graph.context(), sp);

graph.prepack_nodes().emplace_back(new PrepackNode(
graph, shader, global_size, local_size, vref, v, std::move(params)));
graph,
shader,
global_size,
local_size,
vref,
v,
{graph.create_params_buffer(sp)}));

return v;
}
Expand Down
25 changes: 18 additions & 7 deletions backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@ void bind_tensor_to_descriptor_set(
}
}

void bind_staging_to_descriptor_set(
api::StorageBuffer& staging,
api::DescriptorSet& descriptor_set,
const uint32_t idx) {
descriptor_set.bind(idx, staging.buffer());
}

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
Expand Down Expand Up @@ -63,6 +56,24 @@ uint32_t bind_values_to_descriptor_set(
return idx;
}

uint32_t bind_params_to_descriptor_set(
std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx) {
uint32_t idx = base_idx;
for (auto& param : params) {
descriptor_set.bind(idx++, param->buffer());
}
return idx;
}

void bind_staging_to_descriptor_set(
api::StorageBuffer& staging,
api::DescriptorSet& descriptor_set,
const uint32_t idx) {
descriptor_set.bind(idx, staging.buffer());
}

} // namespace vulkan
} // namespace native
} // namespace at
23 changes: 18 additions & 5 deletions backends/vulkan/runtime/graph/ops/utils/BindingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,38 @@ namespace at {
namespace native {
namespace vulkan {

//
// For objects in the graph
//

void bind_tensor_to_descriptor_set(
vTensor& tensor,
api::PipelineBarrier& pipeline_barrier,
const api::MemoryAccessType accessType,
api::DescriptorSet& descriptor_set,
const uint32_t idx);

void bind_staging_to_descriptor_set(
api::StorageBuffer& staging,
api::DescriptorSet& descriptor_set,
const uint32_t idx);

uint32_t bind_values_to_descriptor_set(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
api::PipelineBarrier& pipeline_barrier,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx);

//
// For objects NOT in the graph
//

uint32_t bind_params_to_descriptor_set(
std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
api::DescriptorSet& descriptor_set,
const uint32_t base_idx);

void bind_staging_to_descriptor_set(
api::StorageBuffer& staging,
api::DescriptorSet& descriptor_set,
const uint32_t idx);

} // namespace vulkan
} // namespace native
} // namespace at
Expand Down