Skip to content

Commit 4ddd063

Browse files
committed
[ET-VK] Introduce ParamsBindList to prevent needing to pass shared_ptr to bind parameter UBOs
Differential Revision: [D56357188](https://our.internmc.facebook.com/intern/diff/D56357188/) ghstack-source-id: 223199138 Pull Request resolved: #3150
1 parent ea4931e commit 4ddd063

16 files changed

+138
-52
lines changed

backends/vulkan/runtime/api/Context.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,5 +235,28 @@ UniformParamsBuffer& UniformParamsBuffer::operator=(
235235
return *this;
236236
}
237237

238+
ParamsBindList::ParamsBindList(
239+
std::initializer_list<const api::BufferBindInfo> init_list) {
240+
bind_infos.resize(init_list.size());
241+
std::copy(init_list.begin(), init_list.end(), bind_infos.begin());
242+
}
243+
244+
ParamsBindList::ParamsBindList(
245+
std::initializer_list<const api::UniformParamsBuffer*> init_list) {
246+
bind_infos.resize(init_list.size());
247+
for (int i = 0; i < init_list.size(); ++i) {
248+
bind_infos[i] = api::BufferBindInfo(init_list.begin()[i]->buffer());
249+
}
250+
}
251+
252+
ParamsBindList::ParamsBindList(
253+
std::initializer_list<std::shared_ptr<api::UniformParamsBuffer>>
254+
init_list) {
255+
bind_infos.resize(init_list.size());
256+
for (int i = 0; i < init_list.size(); ++i) {
257+
bind_infos[i] = api::BufferBindInfo(init_list.begin()[i]->buffer());
258+
}
259+
}
260+
238261
} // namespace api
239262
} // namespace vkcompute

backends/vulkan/runtime/api/Context.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ class UniformParamsBuffer final {
244244
}
245245
}
246246

247-
VulkanBuffer& buffer() {
247+
const VulkanBuffer& buffer() const {
248248
return vulkan_buffer_;
249249
}
250250

@@ -264,6 +264,19 @@ class UniformParamsBuffer final {
264264
}
265265
};
266266

267+
struct ParamsBindList final {
268+
std::vector<api::BufferBindInfo> bind_infos;
269+
270+
ParamsBindList(std::initializer_list<const api::BufferBindInfo> init_list);
271+
272+
ParamsBindList(
273+
std::initializer_list<const api::UniformParamsBuffer*> init_list);
274+
275+
ParamsBindList(
276+
std::initializer_list<std::shared_ptr<api::UniformParamsBuffer>>
277+
init_list);
278+
};
279+
267280
class StorageBuffer final {
268281
private:
269282
Context* context_p_;
@@ -331,6 +344,11 @@ inline void arg_is_empty(bool& any_is_empty, const VulkanImage& image) {
331344
any_is_empty = any_is_empty || !image;
332345
}
333346

347+
inline void arg_is_empty(bool& any_is_empty, const BufferBindInfo& bind_info) {
348+
// bool(image) will evaluate to false if no memory has been allocated
349+
any_is_empty = any_is_empty || (bind_info.handle == VK_NULL_HANDLE);
350+
}
351+
334352
/*
335353
Reports if any VulkanBuffer or VulkanImage argument in a variadic argument
336354
list does not have any memory associated with it.

backends/vulkan/runtime/api/Descriptor.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@
1515
namespace vkcompute {
1616
namespace api {
1717

18+
//
19+
// BufferBinding
20+
//
21+
22+
BufferBindInfo::BufferBindInfo()
23+
: handle(VK_NULL_HANDLE), offset(0u), range(0u) {}
24+
25+
BufferBindInfo::BufferBindInfo(const VulkanBuffer& buffer_p)
26+
: handle(buffer_p.handle()),
27+
offset(buffer_p.mem_offset()),
28+
range(buffer_p.mem_range()) {}
29+
1830
//
1931
// DescriptorSet
2032
//
@@ -66,6 +78,21 @@ DescriptorSet& DescriptorSet::bind(
6678
return *this;
6779
}
6880

81+
DescriptorSet& DescriptorSet::bind(
82+
const uint32_t idx,
83+
const BufferBindInfo& bind_info) {
84+
DescriptorSet::ResourceBinding binder{};
85+
binder.binding_idx = idx; // binding_idx
86+
binder.descriptor_type = shader_layout_signature_[idx]; // descriptor_type
87+
binder.is_image = false; // is_image
88+
binder.resource_info.buffer_info.buffer = bind_info.handle; // buffer
89+
binder.resource_info.buffer_info.offset = bind_info.offset; // offset
90+
binder.resource_info.buffer_info.range = bind_info.range; // range
91+
add_binding(binder);
92+
93+
return *this;
94+
}
95+
6996
DescriptorSet& DescriptorSet::bind(
7097
const uint32_t idx,
7198
const VulkanImage& image) {

backends/vulkan/runtime/api/Descriptor.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@
2020
namespace vkcompute {
2121
namespace api {
2222

23+
/*
24+
* Stores the binding information of a Vulkan Buffer so that the buffer can be
25+
* bound at a later time. This struct should only be used if the buffer to be
26+
* bound is guaranteed to be active at the time of binding.
27+
*/
28+
struct BufferBindInfo final {
29+
VkBuffer handle;
30+
VkDeviceSize offset;
31+
VkDeviceSize range;
32+
33+
BufferBindInfo();
34+
BufferBindInfo(const VulkanBuffer& buffer_p);
35+
};
36+
2337
class DescriptorSet final {
2438
public:
2539
explicit DescriptorSet(VkDevice, VkDescriptorSet, ShaderLayout::Signature);
@@ -50,6 +64,7 @@ class DescriptorSet final {
5064
std::vector<ResourceBinding> bindings_;
5165

5266
public:
67+
DescriptorSet& bind(const uint32_t, const BufferBindInfo&);
5368
DescriptorSet& bind(const uint32_t, const VulkanBuffer&);
5469
DescriptorSet& bind(const uint32_t, const VulkanImage&);
5570

backends/vulkan/runtime/api/Tensor.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ vTensor::vTensor(
140140
sizes_(sizes.begin(), sizes.end()),
141141
gpu_sizes_{calc_gpu_sizes(sizes, memory_layout_, storage_type)},
142142
// Utility Uniform Buffers that can be passed to shaders as arguments
143-
cpu_sizes_uniform_(nullptr),
144-
gpu_sizes_uniform_(nullptr),
145-
extents_uniform_(nullptr),
143+
cpu_sizes_uniform_(),
144+
gpu_sizes_uniform_(),
145+
extents_uniform_(),
146146
// Construct Tensor storage
147147
storage_(
148148
context,
@@ -189,33 +189,33 @@ api::VulkanBuffer& vTensor::buffer(
189189
return storage_.buffer_;
190190
}
191191

192-
std::shared_ptr<api::UniformParamsBuffer> vTensor::cpu_sizes_ubo() {
193-
if (!cpu_sizes_uniform_) {
194-
cpu_sizes_uniform_.reset(new api::UniformParamsBuffer(
195-
storage_.context_, api::utils::make_whcn_ivec4(sizes_)));
192+
const api::BufferBindInfo vTensor::cpu_sizes_ubo() {
193+
if (!cpu_sizes_uniform_.buffer()) {
194+
cpu_sizes_uniform_ = api::UniformParamsBuffer(
195+
storage_.context_, api::utils::make_whcn_ivec4(sizes_));
196196
}
197-
return cpu_sizes_uniform_;
197+
return api::BufferBindInfo(cpu_sizes_uniform_.buffer());
198198
}
199199

200-
std::shared_ptr<api::UniformParamsBuffer> vTensor::gpu_sizes_ubo() {
201-
if (!gpu_sizes_uniform_) {
202-
gpu_sizes_uniform_.reset(new api::UniformParamsBuffer(
203-
storage_.context_, api::utils::make_whcn_ivec4(gpu_sizes_)));
200+
const api::BufferBindInfo vTensor::gpu_sizes_ubo() {
201+
if (!gpu_sizes_uniform_.buffer()) {
202+
gpu_sizes_uniform_ = api::UniformParamsBuffer(
203+
storage_.context_, api::utils::make_whcn_ivec4(gpu_sizes_));
204204
}
205-
return gpu_sizes_uniform_;
205+
return api::BufferBindInfo(gpu_sizes_uniform_.buffer());
206206
}
207207

208-
std::shared_ptr<api::UniformParamsBuffer> vTensor::extents_ubo() {
209-
if (!extents_uniform_) {
210-
extents_uniform_.reset(new api::UniformParamsBuffer(
208+
const api::BufferBindInfo vTensor::extents_ubo() {
209+
if (!extents_uniform_.buffer()) {
210+
extents_uniform_ = api::UniformParamsBuffer(
211211
storage_.context_,
212212
api::utils::uvec4(
213213
{storage_.extents_.data[0],
214214
storage_.extents_.data[1],
215215
storage_.extents_.data[2],
216-
1u})));
216+
1u}));
217217
}
218-
return extents_uniform_;
218+
return api::BufferBindInfo(extents_uniform_.buffer());
219219
}
220220

221221
VmaAllocationCreateInfo vTensor::get_allocation_create_info() const {
@@ -258,16 +258,16 @@ void vTensor::update_size_metadata(const std::vector<int64_t>& new_sizes) {
258258
api::utils::uvec3 virtual_extents =
259259
create_image_extents(gpu_sizes_, storage_type(), memory_layout_);
260260

261-
if (cpu_sizes_uniform_) {
262-
cpu_sizes_uniform_->update(api::utils::make_whcn_ivec4(sizes_));
261+
if (cpu_sizes_uniform_.buffer()) {
262+
cpu_sizes_uniform_.update(api::utils::make_whcn_ivec4(sizes_));
263263
}
264264

265-
if (gpu_sizes_uniform_) {
266-
gpu_sizes_uniform_->update(api::utils::make_whcn_ivec4(gpu_sizes_));
265+
if (gpu_sizes_uniform_.buffer()) {
266+
gpu_sizes_uniform_.update(api::utils::make_whcn_ivec4(gpu_sizes_));
267267
}
268268

269-
if (extents_uniform_) {
270-
extents_uniform_->update(api::utils::uvec4(
269+
if (extents_uniform_.buffer()) {
270+
extents_uniform_.update(api::utils::uvec4(
271271
{virtual_extents.data[0],
272272
virtual_extents.data[1],
273273
virtual_extents.data[2],

backends/vulkan/runtime/api/Tensor.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,17 @@ class vTensor final {
118118

119119
// A Vulkan uniform buffer containing the tensor sizes in WHCN that can be
120120
// passed into a shader.
121-
std::shared_ptr<api::UniformParamsBuffer> cpu_sizes_uniform_;
121+
api::UniformParamsBuffer cpu_sizes_uniform_;
122122

123123
// A Vulkan uniform buffer containing the GPU tensor sizes in WHCN that can
124124
// be passed into a shader. GPU sizes refers to the sizes of the tensor after
125125
// padding has been applied to one dimension to align it to the next multiple
126126
// of 4.
127-
std::shared_ptr<api::UniformParamsBuffer> gpu_sizes_uniform_;
127+
api::UniformParamsBuffer gpu_sizes_uniform_;
128128

129129
// A Vulkan uniform buffer containing the image extents of the underlying
130130
// image texture that can be passed into a shader.
131-
std::shared_ptr<api::UniformParamsBuffer> extents_uniform_;
131+
api::UniformParamsBuffer extents_uniform_;
132132

133133
// Store the backing storage of the tensor as a shared pointer to allow two
134134
// tensors to share the same underlying resource, but with different metadata.
@@ -210,21 +210,21 @@ class vTensor final {
210210
* shader. Note that the UBO will be created the first time this function is
211211
* called.
212212
*/
213-
std::shared_ptr<api::UniformParamsBuffer> cpu_sizes_ubo();
213+
const api::BufferBindInfo cpu_sizes_ubo();
214214

215215
/*
216216
* Get a uniform buffer object containing the tensor GPU sizes to use in a
217217
* compute shader. Note that the UBO will be created the first time this
218218
* function is called.
219219
*/
220-
std::shared_ptr<api::UniformParamsBuffer> gpu_sizes_ubo();
220+
const api::BufferBindInfo gpu_sizes_ubo();
221221

222222
/*
223223
* Get a uniform buffer object containing the image extents to use in a
224224
* compute shader. Note that the UBO will be created the first time this
225225
* function is called.
226226
*/
227-
std::shared_ptr<api::UniformParamsBuffer> extents_ubo();
227+
const api::BufferBindInfo extents_ubo();
228228

229229
inline size_t numel() const {
230230
return api::utils::multiply_integers(sizes());

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ ComputeGraph::ComputeGraph(GraphConfig config)
5959
config_.contextConfig)},
6060
shared_objects_{},
6161
values_{},
62+
param_ubos_{},
6263
prepack_nodes_{},
6364
execute_nodes_{},
6465
inputs_{},

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class ComputeGraph final {
9393
std::unique_ptr<api::Context> context_;
9494
std::vector<SharedObject> shared_objects_;
9595
std::vector<Value> values_;
96+
std::vector<api::UniformParamsBuffer> param_ubos_;
9697

9798
std::vector<std::unique_ptr<PrepackNode>> prepack_nodes_;
9899
std::vector<std::unique_ptr<ExecuteNode>> execute_nodes_;
@@ -314,9 +315,9 @@ class ComputeGraph final {
314315
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
315316

316317
template <typename Block>
317-
inline std::shared_ptr<api::UniformParamsBuffer> create_params_buffer(
318-
const Block& data) {
319-
return std::make_shared<api::UniformParamsBuffer>(context_.get(), data);
318+
const api::BufferBindInfo create_params_buffer(const Block& data) {
319+
param_ubos_.emplace_back(api::UniformParamsBuffer(context_.get(), data));
320+
return api::BufferBindInfo(param_ubos_.back().buffer());
320321
}
321322

322323
/*

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ ExecuteNode::ExecuteNode(
2020
const api::utils::uvec3& global_workgroup_size,
2121
const api::utils::uvec3& local_workgroup_size,
2222
const std::vector<ArgGroup>& args,
23-
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
23+
const api::ParamsBindList& params,
2424
const ResizeFunction& resize_fn,
2525
const std::vector<ValueRef>& resize_args,
2626
const api::SpecVarList& spec_vars)
@@ -47,6 +47,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
4747
uint32_t idx = 0;
4848
idx = bind_values_to_descriptor_set(
4949
graph, args_, pipeline_barrier, descriptor_set, idx);
50+
5051
bind_params_to_descriptor_set(params_, descriptor_set, idx);
5152

5253
context->register_shader_dispatch(

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class ExecuteNode final {
5454
const api::utils::uvec3& global_workgroup_size,
5555
const api::utils::uvec3& local_workgroup_size,
5656
const std::vector<ArgGroup>& args,
57-
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
57+
const api::ParamsBindList& params,
5858
const ResizeFunction& resize_fn = nullptr,
5959
const std::vector<ValueRef>& resize_args = {},
6060
const api::SpecVarList& spec_vars = {});
@@ -74,7 +74,7 @@ class ExecuteNode final {
7474
const api::utils::uvec3 global_workgroup_size_;
7575
const api::utils::uvec3 local_workgroup_size_;
7676
const std::vector<ArgGroup> args_;
77-
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
77+
const api::ParamsBindList params_;
7878
const ResizeFunction resize_fn_;
7979
const std::vector<ValueRef> resize_args_;
8080
const api::SpecVarList spec_vars_;

0 commit comments

Comments
 (0)