Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC: MVKShaderLibrary: Handle specialization with macros #2434

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {

MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const mvk::SPIRVToMSLConversionResultInfo& resultInfo,
const MVKCompressor<std::string> compressedMSL);
const MVKCompressor<std::string> compressedMSL,
const std::vector<std::pair<uint32_t, uint32_t> >* spec_list = nullptr);

MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const void* mslCompiledCodeData,
Expand All @@ -108,7 +109,8 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {
MVKShaderModule* shaderModule);
void handleCompilationError(NSError* err, const char* opDesc);
MTLFunctionConstant* getFunctionConstant(NSArray<MTLFunctionConstant*>* mtlFCs, NSUInteger mtlFCID);
void compileLibrary(const std::string& msl);
void compileLibrary(const std::string& msl,
const std::vector<std::pair<uint32_t, uint32_t> >* spec_list = nullptr);
void compressMSL(const std::string& msl);
void decompressMSL(std::string& msl);
MVKCompressor<std::string>& getCompressedMSL() { return _compressedMSL; }
Expand All @@ -117,6 +119,9 @@ class MVKShaderLibrary : public MVKBaseDeviceObject {
id<MTLLibrary> _mtlLibrary;
MVKCompressor<std::string> _compressedMSL;
mvk::SPIRVToMSLConversionResultInfo _shaderConversionResultInfo;

bool _specialized;
std::map<std::vector<std::pair<uint32_t, uint32_t> >, MVKShaderLibrary *> _spec_variants;
};


Expand Down Expand Up @@ -260,7 +265,8 @@ class MVKShaderLibraryCompiler : public MVKMetalCompiler {
* nanoseconds, an error will be generated and logged, and nil will be returned.
*/
id<MTLLibrary> newMTLLibrary(NSString* mslSourceCode,
const mvk::SPIRVToMSLConversionResultInfo& shaderConversionResults);
const mvk::SPIRVToMSLConversionResultInfo& shaderConversionResults,
const std::vector<std::pair<uint32_t, uint32_t> >* spec_list = nullptr);


#pragma mark Construction
Expand Down
76 changes: 65 additions & 11 deletions MoltenVK/MoltenVK/GPUObjects/MVKShaderModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
#include "MVKShaderModule.h"
#include "MVKPipeline.h"
#include "MVKFoundation.h"
#include <Foundation/Foundation.h>
#include <cstdint>
#include <sys/stat.h>
#include <string>

using namespace std;
using namespace mvk;
Expand Down Expand Up @@ -75,12 +78,40 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

if ( !_mtlLibrary ) { return MVKMTLFunctionNull; }

id<MTLLibrary> lib = _mtlLibrary;

if (pSpecializationInfo && !_specialized) {
std::string msl;
decompressMSL(msl);
std::vector<std::pair<uint32_t, uint32_t> > spec_list;
for (uint32_t specIdx = 0; specIdx < pSpecializationInfo->mapEntryCount; specIdx++) {
std::string const_name = "SPIRV_CROSS_CONSTANT_ID_" + std::to_string(specIdx);
const VkSpecializationMapEntry* pMapEntry = &pSpecializationInfo->pMapEntries[specIdx];
uint32_t spec_val = *(uint32_t *)((char *)pSpecializationInfo->pData + pMapEntry->offset);
if (msl.find(const_name) != std::string::npos) {
spec_list.push_back(std::make_pair(specIdx, spec_val));
}
}

if (!spec_list.empty()) {
auto entry = _spec_variants.find(spec_list);
if (entry != _spec_variants.end()) {
lib = entry->second->_mtlLibrary;
} else {
MVKShaderLibrary *new_mvklib = new MVKShaderLibrary(_owner, _shaderConversionResultInfo, _compressedMSL, &spec_list);
_spec_variants[spec_list] = new_mvklib;
lib = new_mvklib->_mtlLibrary;
}
}
}


@synchronized (getMTLDevice()) {
@autoreleasepool {
NSString* mtlFuncName = @(_shaderConversionResultInfo.entryPoint.mtlFunctionName.c_str());

uint64_t startTime = pShaderFeedback ? mvkGetTimestamp() : getPerformanceTimestamp();
id<MTLFunction> mtlFunc = [[_mtlLibrary newFunctionWithName: mtlFuncName] autorelease];
id<MTLFunction> mtlFunc = [[lib newFunctionWithName: mtlFuncName] autorelease];
addPerformanceInterval(getPerformanceStats().shaderCompilation.functionRetrieval, startTime);
if (pShaderFeedback) {
if (mtlFunc) {
Expand Down Expand Up @@ -120,7 +151,7 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
if (pShaderFeedback) {
startTime = mvkGetTimestamp();
}
mtlFunc = [fs.newMTLFunction(_mtlLibrary, mtlFuncName, mtlFCVals) autorelease];
mtlFunc = [fs.newMTLFunction(lib, mtlFuncName, mtlFCVals) autorelease];
if (pShaderFeedback) {
pShaderFeedback->duration += mvkGetElapsedNanoseconds(startTime);
}
Expand Down Expand Up @@ -169,7 +200,8 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
MVKShaderLibrary::MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const SPIRVToMSLConversionResult& conversionResult) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_specialized(false) {

_shaderConversionResultInfo = conversionResult.resultInfo;
compressMSL(conversionResult.msl);
Expand All @@ -178,21 +210,24 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

// Constructs a shader library from already-compressed MSL source.
//
// Stores the supplied conversion result info and compressed MSL, then
// decompresses the MSL and compiles it into a Metal library.
//
// owner         - the Vulkan API object that owns this library.
// resultInfo    - SPIR-V to MSL conversion info copied into this library.
// compressedMSL - compressed MSL source copied into this library.
// spec_list     - optional list of (uint32_t, uint32_t) pairs forwarded to
//                 compileLibrary(). When non-null, this library is marked
//                 as specialized.
MVKShaderLibrary::MVKShaderLibrary(MVKVulkanAPIDeviceObject* owner,
const SPIRVToMSLConversionResultInfo& resultInfo,
const MVKCompressor<std::string> compressedMSL) :
const MVKCompressor<std::string> compressedMSL,
const std::vector<std::pair<uint32_t, uint32_t> >* spec_list) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_specialized(spec_list != nullptr) {

_shaderConversionResultInfo = resultInfo;
_compressedMSL = compressedMSL;
string msl;
// The MSL is held compressed, so expand it into a temporary string before compiling.
decompressMSL(msl);
compileLibrary(msl);
compileLibrary(msl, spec_list);
}

// Compiles the given MSL source into _mtlLibrary.
//
// A temporary MVKShaderLibraryCompiler is created for the compilation and
// destroyed once it has finished.
//
// msl       - the MSL source code to compile.
// spec_list - optional list of (uint32_t, uint32_t) pairs passed through
//             unchanged to MVKShaderLibraryCompiler::newMTLLibrary().
void MVKShaderLibrary::compileLibrary(const string& msl) {
void MVKShaderLibrary::compileLibrary(const string& msl,
const std::vector<std::pair<uint32_t, uint32_t> >* spec_list) {
MVKShaderLibraryCompiler* slc = new MVKShaderLibraryCompiler(_owner);
NSString* nsSrc = [[NSString alloc] initWithUTF8String: msl.c_str()]; // temp retained
_mtlLibrary = slc->newMTLLibrary(nsSrc, _shaderConversionResultInfo); // retained
_mtlLibrary = slc->newMTLLibrary(nsSrc, _shaderConversionResultInfo, spec_list); // retained
[nsSrc release]; // release temp string
slc->destroy();
}
Expand All @@ -201,7 +236,8 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
const void* mslCompiledCodeData,
size_t mslCompiledCodeLength) :
MVKBaseDeviceObject(owner->getDevice()),
_owner(owner) {
_owner(owner),
_specialized(true) {

uint64_t startTime = getPerformanceTimestamp();
@autoreleasepool {
Expand All @@ -219,7 +255,9 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

MVKShaderLibrary::MVKShaderLibrary(const MVKShaderLibrary& other) :
MVKBaseDeviceObject(other._device),
_owner(other._owner) {
_owner(other._owner),
_specialized(other._specialized),
_spec_variants(other._spec_variants) {

_mtlLibrary = [other._mtlLibrary retain];
_shaderConversionResultInfo = other._shaderConversionResultInfo;
Expand Down Expand Up @@ -255,6 +293,11 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD

// Releases the retained Metal library and, for libraries marked as
// specialized, deletes every variant library held in _spec_variants.
//
// NOTE(review): the only visible code that inserts into _spec_variants runs
// when _specialized is false, so this guard appears to skip exactly the
// libraries that own variants - confirm whether the condition is inverted.
// NOTE(review): the copy constructor copies the _spec_variants pointers, so
// confirm which instance owns the variants before deleting them here.
MVKShaderLibrary::~MVKShaderLibrary() {
[_mtlLibrary release];
if (_specialized) {
for (auto& item: _spec_variants) {
delete item.second;
}
}
}


Expand Down Expand Up @@ -499,14 +542,25 @@ static uint32_t getWorkgroupDimensionSize(const SPIRVWorkgroupSizeDimension& wgD
#pragma mark MVKShaderLibraryCompiler

id<MTLLibrary> MVKShaderLibraryCompiler::newMTLLibrary(NSString* mslSourceCode,
const SPIRVToMSLConversionResultInfo& shaderConversionResults) {
const SPIRVToMSLConversionResultInfo& shaderConversionResults,
const std::vector<std::pair<uint32_t, uint32_t> >* spec_list) {
unique_lock<mutex> lock(_completionLock);

compile(lock, ^{
auto mtlDev = getMTLDevice();
@synchronized (mtlDev) {
auto mtlCompileOptions = getDevice()->getMTLCompileOptions(shaderConversionResults.entryPoint.supportsFastMath,
shaderConversionResults.isPositionInvariant);
if (spec_list != nullptr) {
NSString *macro_names[spec_list->size()];
NSNumber *macro_values[spec_list->size()];
for (uint32_t i = 0; i < spec_list->size(); i++) {
std::string const_name = "SPIRV_CROSS_CONSTANT_ID_" + std::to_string(spec_list->at(i).first);
macro_names[i] = @(const_name.c_str());
macro_values[i] = @(spec_list->at(i).second);
}
mtlCompileOptions.preprocessorMacros = [NSDictionary dictionaryWithObjects:macro_values forKeys:macro_names count:spec_list->size()];
}
MVKLogInfoIf(getMVKConfig().debugMode, "Compiling Metal shader%s.", mtlCompileOptions.fastMathEnabled ? " with FastMath enabled" : "");
[mtlDev newLibraryWithSource: mslSourceCode
options: mtlCompileOptions
Expand Down