Skip to content

[clang][NFC] Move more functions to SemaHLSL #88354

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -2940,13 +2940,6 @@ class Sema final : public SemaBase {
QualType NewT, QualType OldT);
void CheckMain(FunctionDecl *FD, const DeclSpec &D);
void CheckMSVCRTEntryPoint(FunctionDecl *FD);
void ActOnHLSLTopLevelFunction(FunctionDecl *FD);
void CheckHLSLEntryPoint(FunctionDecl *FD);
void CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseHLSLAttrStageMismatch(
const Attr *A, HLSLShaderAttr::ShaderType Stage,
std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD,
bool IsDefinition);
void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D);
Expand Down Expand Up @@ -3707,14 +3700,6 @@ class Sema final : public SemaBase {
StringRef UuidAsWritten, MSGuidDecl *GuidDecl);

BTFDeclTagAttr *mergeBTFDeclTagAttr(Decl *D, const BTFDeclTagAttr &AL);
HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D,
const AttributeCommonInfo &AL,
int X, int Y, int Z);
HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLShaderAttr::ShaderType ShaderType);
HLSLParamModifierAttr *
mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLParamModifierAttr::Spelling Spelling);

WebAssemblyImportNameAttr *
mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);
Expand Down
27 changes: 23 additions & 4 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,42 @@
#ifndef LLVM_CLANG_SEMA_SEMAHLSL_H
#define LLVM_CLANG_SEMA_SEMAHLSL_H

#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclBase.h"
#include "clang/AST/Expr.h"
#include "clang/Basic/AttributeCommonInfo.h"
#include "clang/Basic/IdentifierTable.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Sema/Scope.h"
#include "clang/Sema/SemaBase.h"
#include <initializer_list>

namespace clang {

class SemaHLSL : public SemaBase {
public:
SemaHLSL(Sema &S);

Decl *ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer,
SourceLocation KwLoc, IdentifierInfo *Ident,
SourceLocation IdentLoc, SourceLocation LBrace);
void ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace);
Decl *ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc,
IdentifierInfo *Ident, SourceLocation IdentLoc,
SourceLocation LBrace);
void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace);
HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D,
const AttributeCommonInfo &AL, int X,
int Y, int Z);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLShaderAttr::ShaderType ShaderType);
HLSLParamModifierAttr *
mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLParamModifierAttr::Spelling Spelling);
void ActOnTopLevelFunction(FunctionDecl *FD);
void CheckEntryPoint(FunctionDecl *FD);
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
const Attr *A, HLSLShaderAttr::ShaderType Stage,
std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages);
};

} // namespace clang
Expand Down
10 changes: 5 additions & 5 deletions clang/lib/Parse/ParseHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
return nullptr;
}

Decl *D = Actions.HLSL().ActOnStartHLSLBuffer(
getCurScope(), IsCBuffer, BufferLoc, Identifier, IdentifierLoc,
T.getOpenLocation());
Decl *D = Actions.HLSL().ActOnStartBuffer(getCurScope(), IsCBuffer, BufferLoc,
Identifier, IdentifierLoc,
T.getOpenLocation());

while (Tok.isNot(tok::r_brace) && Tok.isNot(tok::eof)) {
// FIXME: support attribute on constants inside cbuffer/tbuffer.
Expand All @@ -88,15 +88,15 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) {
T.skipToEnd();
DeclEnd = T.getCloseLocation();
BufferScope.Exit();
Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);
return nullptr;
}
}

T.consumeClose();
DeclEnd = T.getCloseLocation();
BufferScope.Exit();
Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd);
Actions.HLSL().ActOnFinishBuffer(D, DeclEnd);

Actions.ProcessDeclAttributeList(Actions.CurScope, D, Attrs);
return D;
Expand Down
130 changes: 6 additions & 124 deletions clang/lib/Sema/SemaDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "clang/Sema/ParsedTemplate.h"
#include "clang/Sema/Scope.h"
#include "clang/Sema/ScopeInfo.h"
#include "clang/Sema/SemaHLSL.h"
#include "clang/Sema/SemaInternal.h"
#include "clang/Sema/Template.h"
#include "llvm/ADT/SmallString.h"
Expand Down Expand Up @@ -2972,10 +2973,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
else if (const auto *BTFA = dyn_cast<BTFDeclTagAttr>(Attr))
NewAttr = S.mergeBTFDeclTagAttr(D, *BTFA);
else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr))
NewAttr =
S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ());
NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(),
NT->getZ());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getType());
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
// Do nothing. Each redeclaration should be suppressed separately.
NewAttr = nullptr;
Expand Down Expand Up @@ -10809,10 +10810,10 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
if (getLangOpts().HLSL && D.isFunctionDefinition()) {
// Any top level function could potentially be specified as an entry.
if (!NewFD->isInvalidDecl() && S->getDepth() == 0 && Name.isIdentifier())
ActOnHLSLTopLevelFunction(NewFD);
HLSL().ActOnTopLevelFunction(NewFD);

if (NewFD->hasAttr<HLSLShaderAttr>())
CheckHLSLEntryPoint(NewFD);
HLSL().CheckEntryPoint(NewFD);
}

// If this is the first declaration of a library builtin function, add
Expand Down Expand Up @@ -12660,125 +12661,6 @@ void Sema::CheckMSVCRTEntryPoint(FunctionDecl *FD) {
}
}

void Sema::ActOnHLSLTopLevelFunction(FunctionDecl *FD) {
auto &TargetInfo = getASTContext().getTargetInfo();

if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry)
return;

StringRef Env = TargetInfo.getTriple().getEnvironmentName();
HLSLShaderAttr::ShaderType ShaderType;
if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) {
if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) {
// The entry point is already annotated - check that it matches the
// triple.
if (Shader->getType() != ShaderType) {
Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch)
<< Shader;
FD->setInvalidDecl();
}
} else {
// Implicitly add the shader attribute if the entry function isn't
// explicitly annotated.
FD->addAttr(HLSLShaderAttr::CreateImplicit(Context, ShaderType,
FD->getBeginLoc()));
}
} else {
switch (TargetInfo.getTriple().getEnvironment()) {
case llvm::Triple::UnknownEnvironment:
case llvm::Triple::Library:
break;
default:
llvm_unreachable("Unhandled environment in triple");
}
}
}

void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) {
const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();

switch (ST) {
case HLSLShaderAttr::Pixel:
case HLSLShaderAttr::Vertex:
case HLSLShaderAttr::Geometry:
case HLSLShaderAttr::Hull:
case HLSLShaderAttr::Domain:
case HLSLShaderAttr::RayGeneration:
case HLSLShaderAttr::Intersection:
case HLSLShaderAttr::AnyHit:
case HLSLShaderAttr::ClosestHit:
case HLSLShaderAttr::Miss:
case HLSLShaderAttr::Callable:
if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) {
DiagnoseHLSLAttrStageMismatch(NT, ST,
{HLSLShaderAttr::Compute,
HLSLShaderAttr::Amplification,
HLSLShaderAttr::Mesh});
FD->setInvalidDecl();
}
break;

case HLSLShaderAttr::Compute:
case HLSLShaderAttr::Amplification:
case HLSLShaderAttr::Mesh:
if (!FD->hasAttr<HLSLNumThreadsAttr>()) {
Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads)
<< HLSLShaderAttr::ConvertShaderTypeToStr(ST);
FD->setInvalidDecl();
}
break;
}

for (ParmVarDecl *Param : FD->parameters()) {
if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) {
CheckHLSLSemanticAnnotation(FD, Param, AnnotationAttr);
} else {
// FIXME: Handle struct parameters where annotations are on struct fields.
// See: https://github.com/llvm/llvm-project/issues/57875
Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation);
Diag(Param->getLocation(), diag::note_previous_decl) << Param;
FD->setInvalidDecl();
}
}
// FIXME: Verify return type semantic annotation.
}

void Sema::CheckHLSLSemanticAnnotation(
FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr) {
auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>();
assert(ShaderAttr && "Entry point has no shader attribute");
HLSLShaderAttr::ShaderType ST = ShaderAttr->getType();

switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
if (ST == HLSLShaderAttr::Compute)
return;
DiagnoseHLSLAttrStageMismatch(AnnotationAttr, ST,
{HLSLShaderAttr::Compute});
break;
default:
llvm_unreachable("Unknown HLSLAnnotationAttr");
}
}

void Sema::DiagnoseHLSLAttrStageMismatch(
const Attr *A, HLSLShaderAttr::ShaderType Stage,
std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) {
SmallVector<StringRef, 8> StageStrings;
llvm::transform(AllowedStages, std::back_inserter(StageStrings),
[](HLSLShaderAttr::ShaderType ST) {
return StringRef(
HLSLShaderAttr::ConvertShaderTypeToStr(ST));
});
Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage)
<< A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage)
<< (AllowedStages.size() != 1) << join(StageStrings, ", ");
}

bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) {
// FIXME: Need strict checking. In C89, we need to check for
// any assignment, increment, decrement, function-calls, or
Expand Down
54 changes: 4 additions & 50 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "clang/Sema/ParsedAttr.h"
#include "clang/Sema/Scope.h"
#include "clang/Sema/ScopeInfo.h"
#include "clang/Sema/SemaHLSL.h"
#include "clang/Sema/SemaInternal.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
Expand Down Expand Up @@ -7238,24 +7239,11 @@ static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
return;
}

HLSLNumThreadsAttr *NewAttr = S.mergeHLSLNumThreadsAttr(D, AL, X, Y, Z);
HLSLNumThreadsAttr *NewAttr = S.HLSL().mergeNumThreadsAttr(D, AL, X, Y, Z);
if (NewAttr)
D->addAttr(NewAttr);
}

HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D,
const AttributeCommonInfo &AL,
int X, int Y, int Z) {
if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) {
if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) {
Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
Diag(AL.getLoc(), diag::note_conflicting_attribute);
}
return nullptr;
}
return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z);
}

static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
if (!T->hasUnsignedIntegerRepresentation())
return false;
Expand Down Expand Up @@ -7299,24 +7287,11 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {

// FIXME: check function match the shader stage.

HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, ShaderType);
HLSLShaderAttr *NewAttr = S.HLSL().mergeShaderAttr(D, AL, ShaderType);
if (NewAttr)
D->addAttr(NewAttr);
}

HLSLShaderAttr *
Sema::mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLShaderAttr::ShaderType ShaderType) {
if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) {
if (NT->getType() != ShaderType) {
Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
Diag(AL.getLoc(), diag::note_conflicting_attribute);
}
return nullptr;
}
return HLSLShaderAttr::Create(Context, ShaderType, AL);
}

static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
const ParsedAttr &AL) {
StringRef Space = "space0";
Expand Down Expand Up @@ -7391,34 +7366,13 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,

static void handleHLSLParamModifierAttr(Sema &S, Decl *D,
const ParsedAttr &AL) {
HLSLParamModifierAttr *NewAttr = S.mergeHLSLParamModifierAttr(
HLSLParamModifierAttr *NewAttr = S.HLSL().mergeParamModifierAttr(
D, AL,
static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
if (NewAttr)
D->addAttr(NewAttr);
}

HLSLParamModifierAttr *
Sema::mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
HLSLParamModifierAttr::Spelling Spelling) {
// We can only merge an `in` attribute with an `out` attribute. All other
// combinations of duplicated attributes are ill-formed.
if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
(PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
D->dropAttr<HLSLParamModifierAttr>();
SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
return HLSLParamModifierAttr::Create(
Context, /*MergedSpelling=*/true, AdjustedRange,
HLSLParamModifierAttr::Keyword_inout);
}
Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
Diag(PA->getLocation(), diag::note_conflicting_attribute);
return nullptr;
}
return HLSLParamModifierAttr::Create(Context, AL);
}

static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
if (!S.LangOpts.CPlusPlus) {
S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
Expand Down
Loading