Skip to content

Commit

Permalink
Add support of PushConstants in Vulkan ProgramBindings via ProgramArg…
Browse files Browse the repository at this point in the history
…umentValueType::RootConstantValue
  • Loading branch information
egorodet committed Oct 26, 2024
1 parent 1ec3a62 commit 2756175
Show file tree
Hide file tree
Showing 13 changed files with 169 additions and 58 deletions.
1 change: 1 addition & 0 deletions Apps/08-ConsoleCompute/Shaders/GameOfLife.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Compute shader for Conway's Game of Life

#include "GameOfLifeRules.h"

[[vk::push_constant]]
ConstantBuffer<Constants> g_constants : register(b0, META_ARG_CONSTANT);
RWTexture2D<uint> g_frame_texture : register(u0, META_ARG_MUTABLE);

Expand Down
8 changes: 1 addition & 7 deletions Apps/08-ConsoleCompute/Shaders/GameOfLifeRules.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ Game of Life Rules constants
#include <vector>
#include <string>

#define HLSL_ALIGN alignas(16)

// Game of Life alternative rules:
// https://conwaylife.com/wiki/List_of_Life-like_rules
// NOTE: indices of rules should match constants in "Shaders/GameOfLifeRules.h"
Expand All @@ -43,10 +41,6 @@ static const std::vector<std::string> g_gol_rule_labels{

using uint = uint32_t;

#else

#define HLSL_ALIGN

#endif

static const uint g_game_rule_classic = 0; // B3/S23
Expand All @@ -56,7 +50,7 @@ static const uint g_game_rule_coral = 3; // B3/S45678
static const uint g_game_rule_geology = 4; // B3578/S24678
static const uint g_game_rule_vote = 5; // B5678/S45678

struct HLSL_ALIGN Constants
struct Constants
{
uint game_rule_id;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class ProgramArgumentBinding
void SetEmitCallbackEnabled(bool enabled) { m_emit_callback_enabled = enabled; }

Ptr<ProgramArgumentBinding> GetPtr() { return shared_from_this(); }
RootConstantAccessor* GetRootConstantAccessorPtr() const { return m_root_constant_accessor_ptr.get(); }

void Initialize(Program& program, Data::Index frame_index);
bool IsAlreadyApplied(const Rhi::IProgram& program,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class RootConstantAccessor
bool IsInitialized() const noexcept { return m_is_initialized; }
const Range& GetBufferRange() const noexcept { return m_buffer_range; }
Data::Size GetDataSize() const noexcept { return m_data_size; }
Data::Byte* GetDataPtr();
Rhi::ResourceView GetResourceView() const;
RootConstantStorage& GetRootConstantBuffer() const { return m_storage_ref.get(); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,18 +182,18 @@ ProgramArgumentBinding::operator std::string() const
void ProgramArgumentBinding::Initialize(Program& program, Data::Index frame_index)
{
META_FUNCTION_TASK();
if (m_settings.argument.IsRootConstant() && !m_root_constant_accessor_ptr)
if (!m_settings.argument.IsRootConstant() || m_root_constant_accessor_ptr)
return;

if (m_settings.argument.IsRootConstantValue())
{
m_root_constant_accessor_ptr = program.GetRootConstantStorage().ReserveRootConstant(m_settings.buffer_size);
}
else
{
if (m_settings.argument.IsRootConstantValue())
{
m_root_constant_accessor_ptr = program.GetRootConstantStorage().ReserveRootConstant(m_settings.buffer_size);
}
else
{
RootConstantBuffer& root_constant_buffer = program.GetRootConstantBuffer(m_settings.argument.GetAccessorType(), frame_index);
m_root_constant_accessor_ptr = root_constant_buffer.ReserveRootConstant(m_settings.buffer_size);
root_constant_buffer.Connect(*this);
}
RootConstantBuffer& root_constant_buffer = program.GetRootConstantBuffer(m_settings.argument.GetAccessorType(), frame_index);
m_root_constant_accessor_ptr = root_constant_buffer.ReserveRootConstant(m_settings.buffer_size);
root_constant_buffer.Connect(*this);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ Rhi::ResourceView RootConstantAccessor::GetResourceView() const
return root_constant_buffer.GetResourceView(m_buffer_range.GetStart(), m_data_size);
}

Data::Byte* RootConstantAccessor::GetDataPtr()
{
META_FUNCTION_TASK();
return m_storage_ref.get().GetData().data() + m_buffer_range.GetStart();
}

//////////////////// RootConstantStorage ////////////////////

RootConstantStorage::~RootConstantStorage()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class Program final
std::vector<vk::PipelineShaderStageCreateInfo> GetNativeShaderStageCreateInfos() const;
vk::PipelineVertexInputStateCreateInfo GetNativeVertexInputStateCreateInfo() const;
const std::vector<vk::DescriptorSetLayout>& GetNativeDescriptorSetLayouts() const;
const std::vector<vk::PushConstantRange>& GetNativePushConstantRanges() const;
const vk::DescriptorSetLayout& GetNativeDescriptorSetLayout(ArgumentAccessor::Type argument_access_type) const;
const DescriptorSetLayoutInfo& GetDescriptorSetLayoutInfo(ArgumentAccessor::Type argument_access_type) const;
const vk::PipelineLayout& GetNativePipelineLayout() const;
Expand All @@ -90,6 +91,7 @@ class Program final
DescriptorSetLayoutInfoByAccessType m_descriptor_set_layout_info_by_access_type;
std::vector<vk::UniqueDescriptorSetLayout> m_vk_unique_descriptor_set_layouts;
std::vector<vk::DescriptorSetLayout> m_vk_descriptor_set_layouts;
std::vector<vk::PushConstantRange> m_vk_push_constant_ranges;
vk::UniquePipelineLayout m_vk_unique_pipeline_layout;
std::optional<vk::DescriptorSet> m_vk_constant_descriptor_set_opt;
std::vector<vk::DescriptorSet> m_vk_frame_constant_descriptor_sets;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ class ProgramArgumentBinding final
ProgramArgumentBinding(const ProgramArgumentBinding& other) = default;

const Settings& GetVulkanSettings() const noexcept { return m_settings_vk; }
uint32_t GetPushConstantsOffset() const noexcept { return m_vk_push_constants_offset; }
vk::ShaderStageFlagBits GetNativeShaderStageFlags() const;

void SetDescriptorSetBinding(const vk::DescriptorSet& descriptor_set, uint32_t layout_binding_index) noexcept;
void SetDescriptorSet(const vk::DescriptorSet& descriptor_set) noexcept;
void SetPushConstantsOffset(uint32_t push_constant_offset) noexcept;

// Base::ProgramArgumentBinding interface
[[nodiscard]] Ptr<Base::ProgramArgumentBinding> CreateCopy() const override;
Expand All @@ -82,7 +85,8 @@ class ProgramArgumentBinding final

Settings m_settings_vk;
vk::DescriptorSet m_vk_descriptor_set;
uint32_t m_vk_binding_value = 0U;
uint32_t m_vk_binding_value = 0U;
uint32_t m_vk_push_constants_offset = 0U;
vk::WriteDescriptorSet m_vk_write_descriptor_set;
std::vector<vk::DescriptorImageInfo> m_vk_descriptor_images;
std::vector<vk::DescriptorBufferInfo> m_vk_descriptor_buffers;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,22 @@ class ProgramBindings final
void UpdateDynamicDescriptorOffsets();
void UpdateMutableDescriptorSetName();

struct PushConstantSetter
{
Rhi::ProgramArgumentAccessType access_type;
vk::ShaderStageFlags shader_stages{};
uint32_t offset;
Ref<Base::RootConstantAccessor> root_const_accessor_ref;

PushConstantSetter(Rhi::ProgramArgumentAccessType access_type,
vk::ShaderStageFlags shader_stages, uint32_t offset,
Base::RootConstantAccessor& root_const_accessor_ref);
};

using PushConstantSetters = std::vector<PushConstantSetter>;

mutable Ptr<Rhi::IResourceBarriers> m_resource_ownership_transition_barriers_ptr;
PushConstantSetters m_push_constant_setters;
std::vector<vk::DescriptorSet> m_descriptor_sets; // descriptor sets corresponding to pipeline layout in the order of their access type
bool m_has_mutable_descriptor_set = false; // if true, then m_descriptor_sets.back() is mutable descriptor set
std::vector<uint32_t> m_dynamic_offsets; // dynamic buffer offsets for all descriptor sets from the bound ResourceView::Settings::offset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ const std::vector<vk::DescriptorSetLayout>& Program::GetNativeDescriptorSetLayou
return m_vk_descriptor_set_layouts;
}

const std::vector<vk::PushConstantRange>& Program::GetNativePushConstantRanges() const
{
META_FUNCTION_TASK();
return m_vk_push_constant_ranges;
}

const vk::DescriptorSetLayout& Program::GetNativeDescriptorSetLayout(Rhi::ProgramArgumentAccessType argument_access_type) const
{
META_FUNCTION_TASK();
Expand Down Expand Up @@ -130,10 +136,10 @@ const vk::PipelineLayout& Program::AcquireNativePipelineLayout()
if (m_vk_unique_pipeline_layout)
return m_vk_unique_pipeline_layout.get();

const std::vector<vk::DescriptorSetLayout>& vk_descriptor_set_layouts = GetNativeDescriptorSetLayouts();
const vk::Device& vk_device = GetVulkanContext().GetVulkanDevice().GetNativeDevice();
const vk::PipelineLayoutCreateInfo pipeline_layout_create_info({}, m_vk_descriptor_set_layouts, m_vk_push_constant_ranges);

m_vk_unique_pipeline_layout = vk_device.createPipelineLayoutUnique(vk::PipelineLayoutCreateInfo({}, vk_descriptor_set_layouts));
m_vk_unique_pipeline_layout = vk_device.createPipelineLayoutUnique(pipeline_layout_create_info);
UpdatePipelineName();

return m_vk_unique_pipeline_layout.get();
Expand Down Expand Up @@ -190,23 +196,38 @@ const vk::DescriptorSet& Program::AcquireFrameConstantDescriptorSet(Data::Index
void Program::InitializeDescriptorSetLayouts()
{
META_FUNCTION_TASK();
uint32_t push_constants_offset = 0U;
m_vk_push_constant_ranges.clear();
for (const auto& [program_argument, argument_binding_ptr] : GetArgumentBindings())
{
META_CHECK_NOT_NULL(argument_binding_ptr);
const auto& vulkan_argument_binding = dynamic_cast<const ProgramBindings::ArgumentBinding&>(*argument_binding_ptr);
auto& vulkan_argument_binding = dynamic_cast<ProgramArgumentBinding&>(*argument_binding_ptr);
const ProgramBindings::ArgumentBinding::Settings& vulkan_binding_settings = vulkan_argument_binding.GetVulkanSettings();
const size_t accessor_type_index = magic_enum::enum_index(vulkan_binding_settings.argument.GetAccessorType()).value();

DescriptorSetLayoutInfo& layout_info = m_descriptor_set_layout_info_by_access_type[accessor_type_index];
layout_info.descriptors_count += vulkan_binding_settings.resource_count;
layout_info.arguments.emplace_back(vulkan_binding_settings.argument);
layout_info.byte_code_maps_for_arguments.emplace_back(vulkan_binding_settings.byte_code_maps);
layout_info.bindings.emplace_back(
static_cast<uint32_t>(layout_info.bindings.size()),
vulkan_binding_settings.descriptor_type,
vulkan_binding_settings.resource_count,
Shader::ConvertTypeToStageFlagBits(program_argument.GetShaderType())
);
if (vulkan_binding_settings.argument.IsRootConstantValue())
{
vulkan_argument_binding.SetPushConstantsOffset(push_constants_offset);
m_vk_push_constant_ranges.emplace_back(
vulkan_argument_binding.GetNativeShaderStageFlags(),
push_constants_offset,
vulkan_binding_settings.buffer_size
);
push_constants_offset += vulkan_binding_settings.buffer_size;
}
else
{
layout_info.descriptors_count += vulkan_binding_settings.resource_count;
layout_info.arguments.emplace_back(vulkan_binding_settings.argument);
layout_info.byte_code_maps_for_arguments.emplace_back(vulkan_binding_settings.byte_code_maps);
layout_info.bindings.emplace_back(
static_cast<uint32_t>(layout_info.bindings.size()),
vulkan_binding_settings.descriptor_type,
vulkan_binding_settings.resource_count,
Shader::ConvertTypeToStageFlagBits(program_argument.GetShaderType())
);
}
}

#ifdef METHANE_LOGGING_ENABLED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ ProgramArgumentBinding::ProgramArgumentBinding(const Base::Context& context, con
, m_settings_vk(settings)
{ }

vk::ShaderStageFlagBits ProgramArgumentBinding::GetNativeShaderStageFlags() const
{
META_FUNCTION_TASK();
return Shader::ConvertTypeToStageFlagBits(m_settings_vk.argument.GetShaderType());
}

void ProgramArgumentBinding::SetDescriptorSetBinding(const vk::DescriptorSet& descriptor_set, uint32_t binding_value) noexcept
{
META_FUNCTION_TASK();
Expand All @@ -71,6 +77,12 @@ void ProgramArgumentBinding::SetDescriptorSet(const vk::DescriptorSet& descripto
}
}

void ProgramArgumentBinding::SetPushConstantsOffset(uint32_t push_constant_offset) noexcept
{
META_FUNCTION_TASK();
m_vk_push_constants_offset = push_constant_offset;
}

Ptr<Base::ProgramArgumentBinding> ProgramArgumentBinding::CreateCopy() const
{
META_FUNCTION_TASK();
Expand Down
Loading

0 comments on commit 2756175

Please sign in to comment.