@@ -325,6 +325,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
325325};
326326} // namespace
327327
328+ static uint32_t GetIntConstAttrArg (ASTContext &astContext, const Expr *expr,
329+ uint32_t defaultVal = 0 ) {
330+ if (expr) {
331+ llvm::APSInt apsInt;
332+ APValue apValue;
333+ if (expr->isIntegerConstantExpr (apsInt, astContext))
334+ return (uint32_t )apsInt.getSExtValue ();
335+ if (expr->isVulkanSpecConstantExpr (astContext, &apValue) && apValue.isInt ())
336+ return (uint32_t )apValue.getInt ().getSExtValue ();
337+ }
338+ return defaultVal;
339+ }
340+
328341// ------------------------------------------------------------------------------
329342//
330343// CGMSHLSLRuntime methods.
@@ -1419,6 +1432,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14191432 }
14201433
14211434 DiagnosticsEngine &Diags = CGM.getDiags ();
1435+ ASTContext &astContext = CGM.getTypes ().getContext ();
14221436
14231437 std::unique_ptr<DxilFunctionProps> funcProps =
14241438 llvm::make_unique<DxilFunctionProps>();
@@ -1629,10 +1643,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16291643
16301644 // Populate numThreads
16311645 if (const HLSLNumThreadsAttr *Attr = FD->getAttr <HLSLNumThreadsAttr>()) {
1632-
1633- funcProps->numThreads [0 ] = Attr->getX ();
1634- funcProps->numThreads [1 ] = Attr->getY ();
1635- funcProps->numThreads [2 ] = Attr->getZ ();
1646+ funcProps->numThreads [0 ] = GetIntConstAttrArg (astContext, Attr->getX (), 1 );
1647+ funcProps->numThreads [1 ] = GetIntConstAttrArg (astContext, Attr->getY (), 1 );
1648+ funcProps->numThreads [2 ] = GetIntConstAttrArg (astContext, Attr->getZ (), 1 );
16361649
16371650 if (isEntry && !SM->IsCS () && !SM->IsMS () && !SM->IsAS ()) {
16381651 unsigned DiagID = Diags.getCustomDiagID (
@@ -1805,7 +1818,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18051818
18061819 if (const auto *pAttr = FD->getAttr <HLSLNodeIdAttr>()) {
18071820 funcProps->NodeShaderID .Name = pAttr->getName ().str ();
1808- funcProps->NodeShaderID .Index = pAttr->getArrayIndex ();
1821+ funcProps->NodeShaderID .Index =
1822+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
18091823 } else {
18101824 funcProps->NodeShaderID .Name = FD->getName ().str ();
18111825 funcProps->NodeShaderID .Index = 0 ;
@@ -1816,20 +1830,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18161830 }
18171831 if (const auto *pAttr = FD->getAttr <HLSLNodeShareInputOfAttr>()) {
18181832 funcProps->NodeShaderSharedInput .Name = pAttr->getName ().str ();
1819- funcProps->NodeShaderSharedInput .Index = pAttr->getArrayIndex ();
1833+ funcProps->NodeShaderSharedInput .Index =
1834+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
18201835 }
18211836 if (const auto *pAttr = FD->getAttr <HLSLNodeDispatchGridAttr>()) {
1822- funcProps->Node .DispatchGrid [0 ] = pAttr->getX ();
1823- funcProps->Node .DispatchGrid [1 ] = pAttr->getY ();
1824- funcProps->Node .DispatchGrid [2 ] = pAttr->getZ ();
1837+ funcProps->Node .DispatchGrid [0 ] =
1838+ GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1839+ funcProps->Node .DispatchGrid [1 ] =
1840+ GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1841+ funcProps->Node .DispatchGrid [2 ] =
1842+ GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
18251843 }
18261844 if (const auto *pAttr = FD->getAttr <HLSLNodeMaxDispatchGridAttr>()) {
1827- funcProps->Node .MaxDispatchGrid [0 ] = pAttr->getX ();
1828- funcProps->Node .MaxDispatchGrid [1 ] = pAttr->getY ();
1829- funcProps->Node .MaxDispatchGrid [2 ] = pAttr->getZ ();
1845+ funcProps->Node .MaxDispatchGrid [0 ] =
1846+ GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1847+ funcProps->Node .MaxDispatchGrid [1 ] =
1848+ GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1849+ funcProps->Node .MaxDispatchGrid [2 ] =
1850+ GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
18301851 }
18311852 if (const auto *pAttr = FD->getAttr <HLSLNodeMaxRecursionDepthAttr>()) {
1832- funcProps->Node .MaxRecursionDepth = pAttr->getCount ();
1853+ funcProps->Node .MaxRecursionDepth =
1854+ GetIntConstAttrArg (astContext, pAttr->getCount (), 0 );
18331855 }
18341856 if (!FD->getAttr <HLSLNumThreadsAttr>()) {
18351857 // NumThreads wasn't specified.
@@ -2343,8 +2365,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23432365 NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23442366
23452367 if (parmDecl->hasAttr <HLSLMaxRecordsAttr>()) {
2346- node.MaxRecords =
2347- parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount ();
2368+ node.MaxRecords = GetIntConstAttrArg (
2369+ astContext,
2370+ parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount (), 1 );
23482371 }
23492372 if (parmDecl->hasAttr <HLSLGloballyCoherentAttr>())
23502373 node.Flags .SetGloballyCoherent ();
@@ -2375,7 +2398,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23752398 // OutputID from attribute
23762399 if (const auto *Attr = parmDecl->getAttr <HLSLNodeIdAttr>()) {
23772400 node.OutputID .Name = Attr->getName ().str ();
2378- node.OutputID .Index = Attr->getArrayIndex ();
2401+ node.OutputID .Index =
2402+ GetIntConstAttrArg (astContext, Attr->getArrayIndex (), 0 );
23792403 } else {
23802404 node.OutputID .Name = parmDecl->getName ().str ();
23812405 node.OutputID .Index = 0 ;
@@ -2434,7 +2458,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24342458 node.MaxRecordsSharedWith = ix;
24352459 }
24362460 if (const auto *Attr = parmDecl->getAttr <HLSLMaxRecordsAttr>())
2437- node.MaxRecords = Attr->getMaxCount ();
2461+ node.MaxRecords = GetIntConstAttrArg (astContext, Attr->getMaxCount (), 0 );
24382462 }
24392463
24402464 if (inputPatchCount > 1 ) {
0 commit comments