Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MVKShaderLibrary: Handle specialization with macros (v2) #2441

Merged
merged 3 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ExternalRevisions/SPIRV-Cross_repo_revision
Original file line number Diff line number Diff line change
@@ -1 +1 @@
022aad4559f4c153f44799b682ce52c92a48fd33
9f9782ed500d57d3801445ba628560177f22a117
11 changes: 10 additions & 1 deletion MoltenVK/MoltenVK/GPUObjects/MVKPipeline.mm
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#if MVK_USE_CEREAL
#include <cereal/archives/binary.hpp>
#include <cereal/types/map.hpp>
#include <cereal/types/string.hpp>
#include <cereal/types/vector.hpp>
#endif
Expand Down Expand Up @@ -2783,7 +2784,15 @@ void serialize(Archive & archive, SPIRVToMSLConversionResultInfo& scr) {
scr.needsInputThreadgroupMem,
scr.needsDispatchBaseBuffer,
scr.needsViewRangeBuffer,
scr.usesPhysicalStorageBufferAddressesCapability);
scr.usesPhysicalStorageBufferAddressesCapability,
scr.specializationMacros);
}

template<class Archive>
void serialize(Archive & archive, MSLSpecializationMacroInfo& info) {
archive(info.name,
info.isFloat,
info.isSigned);
}

}
Expand Down
54 changes: 49 additions & 5 deletions MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,38 @@ typedef struct MVKMTLFunction {
/** A MVKMTLFunction indicating an invalid MTLFunction. The mtlFunction member is nil. */
const MVKMTLFunction MVKMTLFunctionNull(nil, mvk::SPIRVToMSLConversionResultInfo(), MTLSizeMake(1, 1, 1));

/** Wraps a single MTLLibrary. */
/**
 * Holds the raw value of a specialization constant that is passed to the
 * shader compiler as a macro.
 *
 * The value member offers one view per scalar type the value may have,
 * and size is the number of bytes of value that were populated.
 */
typedef struct MVKShaderMacroValue {
	union {
		int8_t   si8;
		uint8_t  ui8;
		int16_t  si16;
		uint16_t ui16;
		int32_t  si32;
		uint32_t ui32;
		int64_t  si64;
		uint64_t ui64;
		float    f32;
		double   f64;
	} value;
	size_t size;

	/**
	 * Orders values by their 64-bit unsigned view first, and by size second,
	 * so that instances can be used inside sorted containers and map keys.
	 *
	 * NOTE(review): the comparison always reads value.ui64, whatever size is,
	 * so it assumes the bytes beyond size hold a consistent pattern
	 * (e.g. zero-initialized before the value is copied in) — confirm at call sites.
	 */
	inline bool operator<(const MVKShaderMacroValue& other) const {
		if (value.ui64 != other.value.ui64) { return value.ui64 < other.value.ui64; }
		return size < other.size;
	}
} MVKShaderMacroValue;

/**
* Wraps a single MTLLibrary, or a set of MTLLibrary variants specialized through macros.
*
* The latter case is used when Vulkan specialization constants cannot be realized with
* Metal function constants. Those specialization constants are turned into macros, and
* when they are specialized, the MTLLibrary has to be *recompiled* from source.
*
* To keep these details transparent to users, when specialization by macro occurs,
* MVKShaderLibrary creates specialized variants (each one also an MVKShaderLibrary) behind
* the scenes, and caches them in a map keyed by the macro-value mapping.
*/
class MVKShaderLibrary : public MVKBaseDeviceObject {

public:
Expand Down Expand Up @@ -84,9 +115,14 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {
MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const mvk::SPIRVToMSLConversionResult& conversionResult);

/**
* When specializationMacroDef is not null, creates a macro-specialized library.
* specializationMacroDef contains (specialization constant ID, value) pairs, and should be sorted.
*/
MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const mvk::SPIRVToMSLConversionResultInfo& resultInfo,
const MVKCompressor<std::string> compressedMSL);
const MVKCompressor<std::string> compressedMSL,
const std::vector<std::pair<uint32_t, MVKShaderMacroValue>>* specializationMacroDef = nullptr);

MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const void* mslCompiledCodeData,
Expand All @@ -108,15 +144,21 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {
MVKShaderModule* shaderModule);
void handleCompilationError(NSError* err, const char* opDesc);
MTLFunctionConstant* getFunctionConstant(NSArray<MTLFunctionConstant*>* mtlFCs, NSUInteger mtlFCID);
void compileLibrary(const std::string& msl);
void compileLibrary(const std::string& msl,
const std::vector<std::pair<uint32_t, MVKShaderMacroValue> >* specializationMacroDef = nullptr);
void compressMSL(const std::string& msl);
void decompressMSL(std::string& msl);
MVKCompressor<std::string>& getCompressedMSL() { return _compressedMSL; }

MVKVulkanAPIDeviceObject* _owner;
id<MTLLibrary> _mtlLibrary;
MVKCompressor<std::string> _compressedMSL;
mvk::SPIRVToMSLConversionResultInfo _shaderConversionResultInfo;
mvk::SPIRVToMSLConversionResultInfo _shaderConversionResultInfo;

/** True when this library was created from source and is not itself a macro-specialized variant. */
bool _maySpecializeWithMacro;
/** Can only be populated when _maySpecializeWithMacro is true */
std::map<std::vector<std::pair<uint32_t, MVKShaderMacroValue>>, MVKShaderLibrary *> _specializationVariants;
};


Expand Down Expand Up @@ -260,7 +302,8 @@ class MVKShaderLibraryCompiler : public MVKMetalCompiler {
* nanoseconds, an error will be generated and logged, and nil will be returned.
*/
id<MTLLibrary> newMTLLibrary(NSString* mslSourceCode,
const mvk::SPIRVToMSLConversionResultInfo& shaderConversionResults);
const mvk::SPIRVToMSLConversionResultInfo& shaderConversionResults,
const std::vector<std::pair<mvk::MSLSpecializationMacroInfo, MVKShaderMacroValue>>& macroDef);


#pragma mark Construction
Expand All @@ -273,6 +316,7 @@ class MVKShaderLibraryCompiler : public MVKMetalCompiler {
~MVKShaderLibraryCompiler() override;

protected:
NSNumber *getMacroValue(const mvk::MSLSpecializationMacroInfo& info, const MVKShaderMacroValue& value);
bool compileComplete(id<MTLLibrary> mtlLibrary, NSError *error);
void handleError() override;

Expand Down
166 changes: 146 additions & 20 deletions MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,47 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }

id<MTLLibrary> lib = _mtlLibrary;

// If specialization happens on constants mapped to macro, find or compile a library variant
// with proper macro definition instead of the "generic" library
if (pSpecializationInfo && _maySpecializeWithMacro) {
// Create the list of macro-value mapping
vector<pair<uint32_t, MVKShaderMacroValue>> spec_list;
for (uint32_t specIdx = 0; specIdx < pSpecializationInfo->mapEntryCount; specIdx++) {
const VkSpecializationMapEntry* pMapEntry = &pSpecializationInfo->pMapEntries[specIdx];
uint32_t const_id = pMapEntry->constantID;
MVKShaderMacroValue macro_value = {};
size_t size = min(pMapEntry->size, sizeof(macro_value.value));

memcpy(&macro_value.value, (char *)pSpecializationInfo->pData + pMapEntry->offset, size);
macro_value.size = size;
if (_shaderConversionResultInfo.specializationMacros.find(const_id) != _shaderConversionResultInfo.specializationMacros.end()) {
spec_list.push_back(make_pair(const_id, macro_value));
}
}

if (!spec_list.empty()) {
// Sort the specialization list before it is used as a key to index the variants
std::sort(spec_list.begin(), spec_list.end());
auto entry = _specializationVariants.find(spec_list);
if (entry != _specializationVariants.end()) {
lib = entry->second->_mtlLibrary;
} else {
MVKShaderLibrary *new_mvklib = new MVKShaderLibrary(_owner, _shaderConversionResultInfo, _compressedMSL, &spec_list);
_specializationVariants[spec_list] = new_mvklib;
lib = new_mvklib->_mtlLibrary;
}
}
}


@synchronized (getMTLDevice()) {
@autoreleasepool {
NSString* mtlFuncName = @(_shaderConversionResultInfo.entryPoint.mtlFunctionName.c_str());

uint64_t startTime = pShaderFeedback ? mvkGetTimestamp() : getPerformanceTimestamp();
id<MTLFunction> mtlFunc = [[_mtlLibrary newFunctionWithName: mtlFuncName] autorelease];
id<MTLFunction> mtlFunc = [[lib newFunctionWithName: mtlFuncName] autorelease];
addPerformanceInterval(getPerformanceStats().shaderCompilation.functionRetrieval, startTime);
if (pShaderFeedback) {
if (mtlFunc) {
Expand Down Expand Up @@ -120,7 +155,7 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
if (pShaderFeedback) {
startTime = mvkGetTimestamp();
}
mtlFunc = [fs.newMTLFunction(_mtlLibrary, mtlFuncName, mtlFCVals) autorelease];
mtlFunc = [fs.newMTLFunction(lib, mtlFuncName, mtlFCVals) autorelease];
if (pShaderFeedback) {
pShaderFeedback->duration += mvkGetElapsedNanoseconds(startTime);
}
Expand Down Expand Up @@ -169,7 +204,8 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
MVKShaderLibrary::MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const SPIRVToMSLConversionResult& conversionResult) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(true) {

_shaderConversionResultInfo = conversionResult.resultInfo;
compressMSL(conversionResult.msl);
Expand All @@ -178,21 +214,36 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

MVKShaderLibrary::MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const SPIRVToMSLConversionResultInfo& resultInfo,
const MVKCompressor<std::string> compressedMSL) :
const MVKCompressor<std::string> compressedMSL,
const vector<pair<uint32_t, MVKShaderMacroValue> >* specializationMacroDef) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(specializationMacroDef == nullptr) {

_shaderConversionResultInfo = resultInfo;
_compressedMSL = compressedMSL;
string msl;
decompressMSL(msl);
compileLibrary(msl);
compileLibrary(msl, specializationMacroDef);
}

void MVKShaderLibrary::compileLibrary(const string& msl) {
void MVKShaderLibrary::compileLibrary(const string& msl,
const vector<pair<uint32_t, MVKShaderMacroValue> >* specializationMacroDef) {
MVKShaderLibraryCompiler* slc = new MVKShaderLibraryCompiler(_owner);
NSString* nsSrc = [[NSString alloc] initWithUTF8String: msl.c_str()]; // temp retained
_mtlLibrary = slc->newMTLLibrary(nsSrc, _shaderConversionResultInfo); // retained

// If specialization macro is used, translate the id to macro information and pass it to compiler
vector<pair<MSLSpecializationMacroInfo, MVKShaderMacroValue>> macro_def;
if (specializationMacroDef) {
for (auto& def: *specializationMacroDef) {
const auto& macro_name_iter = _shaderConversionResultInfo.specializationMacros.find(def.first);
if (macro_name_iter != _shaderConversionResultInfo.specializationMacros.end()) {
macro_def.push_back(make_pair(macro_name_iter->second, def.second));
}
}
}

_mtlLibrary = slc->newMTLLibrary(nsSrc, _shaderConversionResultInfo, macro_def); // retained
[nsSrc release]; // release temp string
slc->destroy();
}
Expand All @@ -201,7 +252,8 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
const void* mslCompiledCodeData,
size_t mslCompiledCodeLength) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_maySpecializeWithMacro(false) {

uint64_t startTime = getPerformanceTimestamp();
@autoreleasepool {
Expand All @@ -219,7 +271,9 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

MVKShaderLibrary::MVKShaderLibrary(const MVKShaderLibrary& other) :
MVKBaseDeviceObject(other._device),
_owner(other._owner) {
_owner(other._owner),
_maySpecializeWithMacro(other._maySpecializeWithMacro),
_specializationVariants(other._specializationVariants) {

_mtlLibrary = [other._mtlLibrary retain];
_shaderConversionResultInfo = other._shaderConversionResultInfo;
Expand Down Expand Up @@ -255,6 +309,10 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

// Releases the wrapped MTLLibrary, then deletes every macro-specialized variant
// held in _specializationVariants.
//
// NOTE(review): the copy constructor copies the raw variant pointers stored in
// _specializationVariants, so a copied library and its source would both delete
// the same variants here — confirm that copies never hold a populated map.
MVKShaderLibrary::~MVKShaderLibrary() {
	[_mtlLibrary release];

	for (auto variantIter = _specializationVariants.begin();
		 variantIter != _specializationVariants.end();
		 ++variantIter) {
		delete variantIter->second;
	}
}


Expand Down Expand Up @@ -499,27 +557,95 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
#pragma mark MVKShaderLibraryCompiler

id<MTLLibrary> MVKShaderLibraryCompiler::newMTLLibrary(NSString* mslSourceCode,
const SPIRVToMSLConversionResultInfo& shaderConversionResults) {
const SPIRVToMSLConversionResultInfo& shaderConversionResults,
const vector<pair<MSLSpecializationMacroInfo, MVKShaderMacroValue>>& specializationMacroDef) {
unique_lock<mutex> lock(_completionLock);

compile(lock, ^{
auto mtlDev = getMTLDevice();
@synchronized (mtlDev) {
auto mtlCompileOptions = getDevice()->getMTLCompileOptions(shaderConversionResults.entryPoint.supportsFastMath,
shaderConversionResults.isPositionInvariant);
MVKLogInfoIf(getMVKConfig().debugMode, "Compiling Metal shader%s.", mtlCompileOptions.fastMathEnabled ? " with FastMath enabled" : "");
[mtlDev newLibraryWithSource: mslSourceCode
options: mtlCompileOptions
completionHandler: ^(id<MTLLibrary> mtlLib, NSError* error) {
bool isLate = compileComplete(mtlLib, error);
if (isLate) { destroy(); }
}];
@autoreleasepool {
auto mtlCompileOptions = getDevice()->getMTLCompileOptions(shaderConversionResults.entryPoint.supportsFastMath,
shaderConversionResults.isPositionInvariant);
if (!specializationMacroDef.empty()) {
size_t macro_count = specializationMacroDef.size();
NSString *macro_names[macro_count];
NSNumber *macro_values[macro_count];
for (uint32_t i = 0; i < specializationMacroDef.size(); i++) {
macro_names[i] = @(specializationMacroDef[i].first.name.c_str());
macro_values[i] = getMacroValue(specializationMacroDef[i].first, specializationMacroDef[i].second);
}
mtlCompileOptions.preprocessorMacros = [NSDictionary dictionaryWithObjects: macro_values
forKeys: macro_names
count: macro_count];
}
MVKLogInfoIf(getMVKConfig().debugMode, "Compiling Metal shader%s.", mtlCompileOptions.fastMathEnabled ? " with FastMath enabled" : "");
[mtlDev newLibraryWithSource: mslSourceCode
options: mtlCompileOptions
completionHandler: ^(id<MTLLibrary> mtlLib, NSError* error) {
bool isLate = compileComplete(mtlLib, error);
if (isLate) { destroy(); }
}];
}
}
});

return [_mtlLibrary retain];
}

// Converts a raw specialization macro value into an autoreleased NSNumber.
//
// info  - describes the macro; isFloat and isSigned select how the raw bytes are read.
// value - the raw bytes, plus the number of bytes that are populated.
//
// Floating-point macros are read as a double when the size equals sizeof(double),
// and as a float for any other size. Integer macros are read at a width of 1, 2, 4
// or 8 bytes, signed or unsigned per info.isSigned; any other size is read as 32 bits.
NSNumber *MVKShaderLibraryCompiler::getMacroValue(const MSLSpecializationMacroInfo& info,
												  const MVKShaderMacroValue& value) {
	const auto& bits = value.value;

	if (info.isFloat) {
		return (value.size == sizeof(double)) ? [NSNumber numberWithDouble: bits.f64]
											  : [NSNumber numberWithFloat: bits.f32];
	}

	if (info.isSigned) {
		switch (value.size) {
			case 1:  return [NSNumber numberWithChar: bits.si8];
			case 2:  return [NSNumber numberWithShort: bits.si16];
			case 8:  return [NSNumber numberWithLongLong: bits.si64];
			case 4:
			default: return [NSNumber numberWithInt: bits.si32];	// Unrecognized sizes fall back to 32 bits
		}
	}

	switch (value.size) {
		case 1:  return [NSNumber numberWithUnsignedChar: bits.ui8];
		case 2:  return [NSNumber numberWithUnsignedShort: bits.ui16];
		case 8:  return [NSNumber numberWithUnsignedLongLong: bits.ui64];
		case 4:
		default: return [NSNumber numberWithUnsignedInt: bits.ui32];	// Unrecognized sizes fall back to 32 bits
	}
}

void MVKShaderLibraryCompiler::handleError() {
if (_mtlLibrary) {
MVKLogInfo("%s compilation succeeded with warnings (Error code %li):\n%s", _compilerType.c_str(),
Expand Down
Loading