-
Notifications
You must be signed in to change notification settings - Fork 826
Implementation of GroupSharedLimit to allow increased GroupSharedMemory #7871
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
795cf94
214553d
9b2e89d
1084671
87f85b4
08881e5
99cd065
4da0fb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -305,6 +305,20 @@ void hlsl::SetShaderProps(PSVRuntimeInfo2 *pInfo2, const DxilModule &DM) { | |
| } | ||
| } | ||
|
|
||
| void hlsl::SetShaderProps(PSVRuntimeInfo4 *pInfo4, const DxilModule &DM) { | ||
| assert(pInfo4); | ||
| const ShaderModel *SM = DM.GetShaderModel(); | ||
| switch (SM->GetKind()) { | ||
| case ShaderModel::Kind::Compute: | ||
| case ShaderModel::Kind::Mesh: | ||
| case ShaderModel::Kind::Amplification: | ||
| pInfo4->GroupSharedLimit = DM.GetGroupSharedLimit(); | ||
| break; | ||
| default: | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| void PSVResourceBindInfo0::Print(raw_ostream &OS) const { | ||
| OS << "PSVResourceBindInfo:\n"; | ||
| OS << " Space: " << Space << "\n"; | ||
|
|
@@ -584,8 +598,9 @@ void PSVDependencyTable::Print(raw_ostream &OS, const char *InputSetName, | |
|
|
||
| void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, | ||
| PSVRuntimeInfo1 *pInfo1, PSVRuntimeInfo2 *pInfo2, | ||
| PSVRuntimeInfo3 *pInfo3, uint8_t ShaderKind, | ||
| const char *EntryName, const char *Comment) { | ||
| PSVRuntimeInfo3 *pInfo3, PSVRuntimeInfo4 *pInfo4, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IMO in this case following the […] |
||
| uint8_t ShaderKind, const char *EntryName, | ||
| const char *Comment) { | ||
| if (pInfo1 && pInfo1->ShaderStage != ShaderKind) | ||
| ShaderKind = pInfo1->ShaderStage; | ||
| OS << Comment << "PSVRuntimeInfo:\n"; | ||
|
|
@@ -808,13 +823,19 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, | |
| OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," | ||
| << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; | ||
| } | ||
| if (pInfo4) { | ||
| OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n"; | ||
| } | ||
| break; | ||
| case PSVShaderKind::Amplification: | ||
| OS << Comment << " Amplification Shader\n"; | ||
| if (pInfo2) { | ||
| OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," | ||
| << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; | ||
| } | ||
| if (pInfo4) { | ||
| OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n"; | ||
| } | ||
| break; | ||
| case PSVShaderKind::Mesh: | ||
| OS << Comment << " Mesh Shader\n"; | ||
|
|
@@ -841,6 +862,9 @@ void hlsl::PrintPSVRuntimeInfo(llvm::raw_ostream &OS, PSVRuntimeInfo0 *pInfo0, | |
| OS << Comment << " NumThreads=(" << pInfo2->NumThreadsX << "," | ||
| << pInfo2->NumThreadsY << "," << pInfo2->NumThreadsZ << ")\n"; | ||
| } | ||
| if (pInfo4) { | ||
| OS << Comment << " GroupSharedLimit=" << pInfo4->GroupSharedLimit << "\n"; | ||
| } | ||
| break; | ||
| case PSVShaderKind::Library: | ||
| case PSVShaderKind::Invalid: | ||
|
|
@@ -887,9 +911,10 @@ void DxilPipelineStateValidation::PrintPSVRuntimeInfo( | |
| PSVRuntimeInfo1 *pInfo1 = m_pPSVRuntimeInfo1; | ||
| PSVRuntimeInfo2 *pInfo2 = m_pPSVRuntimeInfo2; | ||
| PSVRuntimeInfo3 *pInfo3 = m_pPSVRuntimeInfo3; | ||
| PSVRuntimeInfo4 *pInfo4 = m_pPSVRuntimeInfo4; | ||
|
|
||
| hlsl::PrintPSVRuntimeInfo( | ||
| OS, pInfo0, pInfo1, pInfo2, pInfo3, ShaderKind, | ||
| OS, pInfo0, pInfo1, pInfo2, pInfo3, pInfo4, ShaderKind, | ||
| m_pPSVRuntimeInfo3 ? m_StringTable.Get(pInfo3->EntryFunctionName) : "", | ||
| Comment); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -413,12 +413,13 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM, | |||||||||||
| PSVRuntimeInfo0 *PSV0, | ||||||||||||
| PSVRuntimeInfo1 *PSV1, | ||||||||||||
| PSVRuntimeInfo2 *PSV2) { | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should have
Suggested change
|
||||||||||||
| PSVRuntimeInfo3 DMPSV; | ||||||||||||
| memset(&DMPSV, 0, sizeof(PSVRuntimeInfo3)); | ||||||||||||
| PSVRuntimeInfo4 DMPSV; | ||||||||||||
| memset(&DMPSV, 0, sizeof(PSVRuntimeInfo4)); | ||||||||||||
|
|
||||||||||||
| hlsl::SetShaderProps((PSVRuntimeInfo0 *)&DMPSV, DM); | ||||||||||||
| hlsl::SetShaderProps((PSVRuntimeInfo1 *)&DMPSV, DM); | ||||||||||||
| hlsl::SetShaderProps((PSVRuntimeInfo2 *)&DMPSV, DM); | ||||||||||||
| hlsl::SetShaderProps((PSVRuntimeInfo4 *)&DMPSV, DM); | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So, part of the purpose of setting the […]. The function doesn't take PSV3 because it is checked in […]. For PSV4, memcmp would be no good because of the PSV3 string index indirection, so it should probably just check the property directly, which requires adding […] |
||||||||||||
| if (PSV1) { | ||||||||||||
| // Init things not set in InitPSVRuntimeInfo. | ||||||||||||
| DMPSV.ShaderStage = static_cast<uint8_t>(SM->GetKind()); | ||||||||||||
|
|
@@ -447,7 +448,7 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM, | |||||||||||
| if (Mismatched) { | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| std::string Str; | ||||||||||||
| raw_string_ostream OS(Str); | ||||||||||||
| hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, | ||||||||||||
| hlsl::PrintPSVRuntimeInfo(OS, &DMPSV, &DMPSV, &DMPSV, &DMPSV, &DMPSV, | ||||||||||||
| static_cast<uint8_t>(SM->GetKind()), | ||||||||||||
| DM.GetEntryFunctionName().c_str(), ""); | ||||||||||||
| OS.flush(); | ||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3921,6 +3921,18 @@ static void ValidateGlobalVariables(ValidationContext &ValCtx) { | |
| Rule = ValidationRule::SmMaxMSSMSize; | ||
| MaxSize = DXIL::kMaxMSSMSize; | ||
| } | ||
|
|
||
| // Check if the entry function has attribute to override TGSM size. | ||
| if (M.HasDxilEntryProps(M.GetEntryFunction())) { | ||
| DxilEntryProps &EntryProps = M.GetDxilEntryProps(M.GetEntryFunction()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since this code only checks non-library shaders, it seems we are missing library TGSM size validation. So, the node case is completely unchecked. However, we do calculate and store […]. It looks to me like we have a bit of a design gap here. Wouldn't it be better to report the maximum size used (like […])? For the compilation, we can check against the max size attribute to prevent accidental over-allocation, and for validation, we can check against the size reported in PSV0 or RDAT, as well as a version-based check (SM 6.10 required larger than default max). Then you don't need max size in PSV0 or RDAT, just reported usage. The attribute would just drive a compile-time override of default maximum.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I filed an issue on the proposal for this here: microsoft/hlsl-specs#761 |
||
| if (EntryProps.props.IsCS()) { | ||
| unsigned SpecifiedTGSMSize = EntryProps.props.groupSharedLimitBytes; | ||
| if (SpecifiedTGSMSize > 0) { | ||
| MaxSize = SpecifiedTGSMSize; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (TGSMSize > MaxSize) { | ||
| Module::global_iterator GI = M.GetModule()->global_end(); | ||
| GlobalVariable *GV = &*GI; | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1646,6 +1646,36 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { | |||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| if (const HLSLGroupSharedLimitAttr *Attr = | ||||||||||||||
| FD->getAttr<HLSLGroupSharedLimitAttr>()) { | ||||||||||||||
| if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) { | ||||||||||||||
| unsigned DiagID = Diags.getCustomDiagID( | ||||||||||||||
| DiagnosticsEngine::Error, | ||||||||||||||
| "attribute GroupSharedLimit only valid for CS/MS/AS."); | ||||||||||||||
| Diags.Report(Attr->getLocation(), DiagID); | ||||||||||||||
| return; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Only valid for SM6.10+ | ||||||||||||||
| if (!SM->IsSM610Plus()) { | ||||||||||||||
| unsigned DiagID = Diags.getCustomDiagID( | ||||||||||||||
| DiagnosticsEngine::Error, "attribute GroupSharedLimit only valid for " | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe we prefer to place these diagnostics earlier, in SemaHLSL.cpp (probably under |
||||||||||||||
| "Shader Model 6.10 and above."); | ||||||||||||||
| Diags.Report(Attr->getLocation(), DiagID); | ||||||||||||||
| return; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| funcProps->groupSharedLimitBytes = Attr->getLimit(); | ||||||||||||||
| } else { | ||||||||||||||
| if (SM->IsMS()) { // Fallback to default limits | ||||||||||||||
| funcProps->groupSharedLimitBytes = DXIL::kMaxMSSMSize; // 28k For MS | ||||||||||||||
| } else if (SM->IsAS() || SM->IsCS()) { | ||||||||||||||
|
Comment on lines +1670 to +1672
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think relying on
Suggested change
|
||||||||||||||
| funcProps->groupSharedLimitBytes = DXIL::kMaxTGSMSize; // 32k For AS/CS | ||||||||||||||
| } else { | ||||||||||||||
| funcProps->groupSharedLimitBytes = 0; | ||||||||||||||
| } | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Hull shader. | ||||||||||||||
| if (const HLSLPatchConstantFuncAttr *Attr = | ||||||||||||||
| FD->getAttr<HLSLPatchConstantFuncAttr>()) { | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should set `m_bExtraMetadata` if shader model is not 6.10 or higher. This is how validation catches newer metadata used in prior shader models where it's unsupported.