Skip to content

Commit

Permalink
a bit of reflection related code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
floooh committed Mar 12, 2024
1 parent 6362c8f commit a05193f
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 447 deletions.
273 changes: 273 additions & 0 deletions src/shdc/reflection.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
/*
Code for reflection parsing.
*/
#include "shdc.h"

// workaround for Compiler.comparison_ids being protected
class UnprotectedCompiler: spirv_cross::Compiler {
public:
bool is_comparison_sampler(const spirv_cross::SPIRType &type, uint32_t id) {
if (type.basetype == spirv_cross::SPIRType::Sampler) {
return comparison_ids.count(id) > 0;
}
return 0;
}
bool is_used_as_depth_texture(const spirv_cross::SPIRType &type, uint32_t id) {
if (type.basetype == spirv_cross::SPIRType::Image) {
return comparison_ids.count(id) > 0;
}
return 0;
}
};

using namespace spirv_cross;

namespace shdc {

static uniform_t::type_t spirtype_to_uniform_type(const SPIRType& type) {
switch (type.basetype) {
case SPIRType::Float:
if (type.columns == 1) {
// scalar or vec
switch (type.vecsize) {
case 1: return uniform_t::FLOAT;
case 2: return uniform_t::FLOAT2;
case 3: return uniform_t::FLOAT3;
case 4: return uniform_t::FLOAT4;
}
}
else {
// a matrix
if ((type.vecsize == 4) && (type.columns == 4)) {
return uniform_t::MAT4;
}
}
break;
case SPIRType::Int:
if (type.columns == 1) {
switch (type.vecsize) {
case 1: return uniform_t::INT;
case 2: return uniform_t::INT2;
case 3: return uniform_t::INT3;
case 4: return uniform_t::INT4;
}
}
break;
default: break;
}
// fallthrough: invalid type
return uniform_t::INVALID;
}

static image_type_t::type_t spirtype_to_image_type(const SPIRType& type) {
if (type.image.arrayed) {
if (type.image.dim == spv::Dim2D) {
return image_type_t::ARRAY;
}
} else {
switch (type.image.dim) {
case spv::Dim2D: return image_type_t::_2D;
case spv::DimCube: return image_type_t::CUBE;
case spv::Dim3D: return image_type_t::_3D;
default: break;
}
}
// fallthrough: invalid type
return image_type_t::INVALID;
}

static image_sample_type_t::type_t spirtype_to_image_sample_type(const SPIRType& type) {
if (type.image.depth) {
return image_sample_type_t::DEPTH;
} else {
switch (type.basetype) {
case SPIRType::Int:
case SPIRType::Short:
case SPIRType::SByte:
return image_sample_type_t::SINT;
case SPIRType::UInt:
case SPIRType::UShort:
case SPIRType::UByte:
return image_sample_type_t::UINT;
default:
return image_sample_type_t::FLOAT;
}
}
}

static bool spirtype_to_image_multisampled(const SPIRType& type) {
return type.image.ms;
}

reflection_t reflection_t::parse(const Compiler& compiler, const snippet_t& snippet, slang_t::type_t slang) {
reflection_t refl;

ShaderResources shd_resources = compiler.get_shader_resources();
// shader stage
switch (compiler.get_execution_model()) {
case spv::ExecutionModelVertex: refl.stage = stage_t::VS; break;
case spv::ExecutionModelFragment: refl.stage = stage_t::FS; break;
default: refl.stage = stage_t::INVALID; break;
}

// find entry point
const auto entry_points = compiler.get_entry_points_and_stages();
for (const auto& item: entry_points) {
if (compiler.get_execution_model() == item.execution_model) {
refl.entry_point = item.name;
break;
}
}
// stage inputs and outputs
for (const Resource& res_attr: shd_resources.stage_inputs) {
attr_t refl_attr;
refl_attr.slot = compiler.get_decoration(res_attr.id, spv::DecorationLocation);
refl_attr.name = res_attr.name;
refl_attr.sem_name = "TEXCOORD";
refl_attr.sem_index = refl_attr.slot;
refl.inputs[refl_attr.slot] = refl_attr;
}
for (const Resource& res_attr: shd_resources.stage_outputs) {
attr_t refl_attr;
refl_attr.slot = compiler.get_decoration(res_attr.id, spv::DecorationLocation);
refl_attr.name = res_attr.name;
refl_attr.sem_name = "TEXCOORD";
refl_attr.sem_index = refl_attr.slot;
refl.outputs[refl_attr.slot] = refl_attr;
}
// uniform blocks
for (const Resource& ub_res: shd_resources.uniform_buffers) {
std::string n = compiler.get_name(ub_res.id);
uniform_block_t refl_ub;
const SPIRType& ub_type = compiler.get_type(ub_res.base_type_id);
refl_ub.slot = compiler.get_decoration(ub_res.id, spv::DecorationBinding);
refl_ub.size = (int) compiler.get_declared_struct_size(ub_type);
refl_ub.struct_name = ub_res.name;
refl_ub.inst_name = compiler.get_name(ub_res.id);
if (refl_ub.inst_name.empty()) {
refl_ub.inst_name = compiler.get_fallback_name(ub_res.id);
}
refl_ub.flattened = spirvcross_t::can_flatten_uniform_block(compiler, ub_res);
for (int m_index = 0; m_index < (int)ub_type.member_types.size(); m_index++) {
uniform_t refl_uniform;
refl_uniform.name = compiler.get_member_name(ub_res.base_type_id, m_index);
const SPIRType& m_type = compiler.get_type(ub_type.member_types[m_index]);
refl_uniform.type = spirtype_to_uniform_type(m_type);
if (m_type.array.size() > 0) {
refl_uniform.array_count = m_type.array[0];
}
refl_uniform.offset = compiler.type_struct_member_offset(ub_type, m_index);
refl_ub.uniforms.push_back(refl_uniform);
}
refl.uniform_blocks.push_back(refl_ub);
}
// (separate) images
for (const Resource& img_res: shd_resources.separate_images) {
image_t refl_img;
refl_img.slot = compiler.get_decoration(img_res.id, spv::DecorationBinding);
refl_img.name = img_res.name;
const SPIRType& img_type = compiler.get_type(img_res.type_id);
refl_img.type = spirtype_to_image_type(img_type);
if (((UnprotectedCompiler*)&compiler)->is_used_as_depth_texture(img_type, img_res.id)) {
refl_img.sample_type = image_sample_type_t::DEPTH;
} else {
refl_img.sample_type = spirtype_to_image_sample_type(compiler.get_type(img_type.image.type));
}
refl_img.multisampled = spirtype_to_image_multisampled(img_type);
refl.images.push_back(refl_img);
}
// (separate) samplers
for (const Resource& smp_res: shd_resources.separate_samplers) {
const SPIRType& smp_type = compiler.get_type(smp_res.type_id);
sampler_t refl_smp;
refl_smp.slot = compiler.get_decoration(smp_res.id, spv::DecorationBinding);
refl_smp.name = smp_res.name;
// HACK ALERT!
if (((UnprotectedCompiler*)&compiler)->is_comparison_sampler(smp_type, smp_res.id)) {
refl_smp.type = sampler_type_t::COMPARISON;
} else {
refl_smp.type = sampler_type_t::FILTERING;
}
refl.samplers.push_back(refl_smp);
}
// combined image samplers
for (auto& img_smp_res: compiler.get_combined_image_samplers()) {
image_sampler_t refl_img_smp;
refl_img_smp.slot = compiler.get_decoration(img_smp_res.combined_id, spv::DecorationBinding);
refl_img_smp.name = compiler.get_name(img_smp_res.combined_id);
refl_img_smp.image_name = compiler.get_name(img_smp_res.image_id);
refl_img_smp.sampler_name = compiler.get_name(img_smp_res.sampler_id);
refl.image_samplers.push_back(refl_img_smp);
}
// patch textures with overridden image-sample-types
for (auto& img: refl.images) {
const auto* tag = snippet.lookup_image_sample_type_tag(img.name);
if (tag) {
img.sample_type = tag->type;
}
}
// patch samplers with overridden sampler-types
for (auto& smp: refl.samplers) {
const auto* tag = snippet.lookup_sampler_type_tag(smp.name);
if (tag) {
smp.type = tag->type;
}
}
return refl;
}

const uniform_block_t* reflection_t::find_uniform_block_by_slot(int slot) const {
for (const uniform_block_t& ub: this->uniform_blocks) {
if (ub.slot == slot) {
return &ub;
}
}
return nullptr;
}

const image_t* reflection_t::find_image_by_slot(int slot) const {
for (const image_t& img: this->images) {
if (img.slot == slot) {
return &img;
}
}
return nullptr;
}

const sampler_t* reflection_t::find_sampler_by_slot(int slot) const {
for (const sampler_t& smp: this->samplers) {
if (smp.slot == slot) {
return &smp;
}
}
return nullptr;
}

const image_sampler_t* reflection_t::find_image_sampler_by_slot(int slot) const {
for (const image_sampler_t& img_smp: this->image_samplers) {
if (img_smp.slot == slot) {
return &img_smp;
}
}
return nullptr;
}

const image_t* reflection_t::find_image_by_name(const std::string& name) const {
for (const image_t& img: this->images) {
if (img.name == name) {
return &img;
}
}
return nullptr;
}

const sampler_t* reflection_t::find_sampler_by_name(const std::string& name) const {
for (const sampler_t& smp: this->samplers) {
if (smp.name == name) {
return &smp;
}
}
return nullptr;
}

} // namespace shdc
28 changes: 11 additions & 17 deletions src/shdc/shdc.h
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ struct stage_t {
}
};

struct spirvcross_refl_t {
struct reflection_t {
stage_t::type_t stage = stage_t::INVALID;
std::string entry_point;
std::array<attr_t, attr_t::NUM> inputs;
Expand All @@ -673,16 +673,15 @@ struct spirvcross_refl_t {
std::vector<image_t> images;
std::vector<sampler_t> samplers;
std::vector<image_sampler_t> image_samplers;
};

// a helper struct to transfer symbol names from SPIRVCross to Tint
struct spirvcross_wgsl_symbol_table_t {
std::map<uint32_t, std::string> input_names;
std::map<uint32_t, std::string> output_names;
std::map<uint32_t, std::string> uniform_block_struct_names;
std::map<uint32_t, std::string> uniform_block_inst_names;
std::map<uint32_t, std::string> image_names;
std::map<uint32_t, std::string> sampler_names;
static reflection_t parse(const spirv_cross::Compiler& compiler, const snippet_t& snippet, slang_t::type_t slang);

const uniform_block_t* find_uniform_block_by_slot(int slot) const;
const image_t* find_image_by_slot(int slot) const;
const image_t* find_image_by_name(const std::string& name) const;
const sampler_t* find_sampler_by_slot(int slot) const;
const sampler_t* find_sampler_by_name(const std::string& name) const;
const image_sampler_t* find_image_sampler_by_slot(int slot) const;
};

// result of a spirv-cross compilation
Expand All @@ -691,7 +690,7 @@ struct spirvcross_source_t {
int snippet_index = -1;
std::string source_code;
errmsg_t error;
spirvcross_refl_t refl;
reflection_t refl;
};

// spirv-cross wrapper
Expand All @@ -703,6 +702,7 @@ struct spirvcross_t {
std::vector<sampler_t> unique_samplers;

static spirvcross_t translate(const input_t& inp, const spirv_t& spirv, slang_t::type_t slang);
static bool can_flatten_uniform_block(const spirv_cross::Compiler& compiler, const spirv_cross::Resource& ub_res);
int find_source_by_snippet_index(int snippet_index) const;
void dump_debug(errmsg_t::msg_format_t err_fmt, slang_t::type_t slang) const;
};
Expand Down Expand Up @@ -766,12 +766,6 @@ namespace util {
int uniform_size(uniform_t::type_t type, int array_size);
int roundup(int val, int round_to);
std::string mod_prefix(const input_t& inp);
const uniform_block_t* find_uniform_block_by_slot(const spirvcross_refl_t& refl, int slot);
const image_t* find_image_by_slot(const spirvcross_refl_t& refl, int slot);
const image_t* find_image_by_name(const spirvcross_refl_t& refl, const std::string& name);
const sampler_t* find_sampler_by_slot(const spirvcross_refl_t& refl, int slot);
const sampler_t* find_sampler_by_name(const spirvcross_refl_t& refl, const std::string& name);
const image_sampler_t* find_image_sampler_by_slot(const spirvcross_refl_t& refl, int slot);
const spirvcross_source_t* find_spirvcross_source_by_shader_name(const std::string& shader_name, const input_t& inp, const spirvcross_t& spirvcross);
const bytecode_blob_t* find_bytecode_blob_by_shader_name(const std::string& shader_name, const input_t& inp, const bytecode_t& bytecode);
std::string to_camel_case(const std::string& str);
Expand Down
12 changes: 6 additions & 6 deletions src/shdc/sokol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ static void write_stage(const char* indent,
}
L("{}desc.{}.entry = \"{}\";\n", indent, stage_name, src->refl.entry_point);
for (int ub_index = 0; ub_index < uniform_block_t::NUM; ub_index++) {
const uniform_block_t* ub = find_uniform_block_by_slot(src->refl, ub_index);
const uniform_block_t* ub = src->refl.find_uniform_block_by_slot(ub_index);
if (ub) {
L("{}desc.{}.uniform_blocks[{}].size = {};\n", indent, stage_name, ub_index, roundup(ub->size, 16));
L("{}desc.{}.uniform_blocks[{}].layout = SG_UNIFORMLAYOUT_STD140;\n", indent, stage_name, ub_index);
Expand All @@ -461,7 +461,7 @@ static void write_stage(const char* indent,
}
}
for (int img_index = 0; img_index < image_t::NUM; img_index++) {
const image_t* img = find_image_by_slot(src->refl, img_index);
const image_t* img = src->refl.find_image_by_slot(img_index);
if (img) {
L("{}desc.{}.images[{}].used = true;\n", indent, stage_name, img_index);
L("{}desc.{}.images[{}].multisampled = {};\n", indent, stage_name, img_index, img->multisampled ? "true" : "false");
Expand All @@ -470,18 +470,18 @@ static void write_stage(const char* indent,
}
}
for (int smp_index = 0; smp_index < sampler_t::NUM; smp_index++) {
const sampler_t* smp = find_sampler_by_slot(src->refl, smp_index);
const sampler_t* smp = src->refl.find_sampler_by_slot(smp_index);
if (smp) {
L("{}desc.{}.samplers[{}].used = true;\n", indent, stage_name, smp_index);
L("{}desc.{}.samplers[{}].sampler_type = {};\n", indent, stage_name, smp_index, smp_type_to_sokol_type_str(smp->type));
}
}
for (int img_smp_index = 0; img_smp_index < image_sampler_t::NUM; img_smp_index++) {
const image_sampler_t* img_smp = find_image_sampler_by_slot(src->refl, img_smp_index);
const image_sampler_t* img_smp = src->refl.find_image_sampler_by_slot(img_smp_index);
if (img_smp) {
L("{}desc.{}.image_sampler_pairs[{}].used = true;\n", indent, stage_name, img_smp_index);
L("{}desc.{}.image_sampler_pairs[{}].image_slot = {};\n", indent, stage_name, img_smp_index, find_image_by_name(src->refl, img_smp->image_name)->slot);
L("{}desc.{}.image_sampler_pairs[{}].sampler_slot = {};\n", indent, stage_name, img_smp_index, find_sampler_by_name(src->refl, img_smp->sampler_name)->slot);
L("{}desc.{}.image_sampler_pairs[{}].image_slot = {};\n", indent, stage_name, img_smp_index, src->refl.find_image_by_name(img_smp->image_name)->slot);
L("{}desc.{}.image_sampler_pairs[{}].sampler_slot = {};\n", indent, stage_name, img_smp_index, src->refl.find_sampler_by_name(img_smp->sampler_name)->slot);
if (slang_t::is_glsl(slang)) {
L("{}desc.{}.image_sampler_pairs[{}].glsl_name = \"{}\";\n", indent, stage_name, img_smp_index, img_smp->name);
}
Expand Down
Loading

0 comments on commit a05193f

Please sign in to comment.