Skip to content

Use combined-check for type related extension and capability requirements #68033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 102 additions & 24 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ using namespace mlir;
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
const ArrayRef<spirv::Extension> elidedCandidates = {}) {
for (const auto &ors : candidates) {
if (targetEnv.allows(ors))
if (targetEnv.allows(ors) ||
llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
return llvm::is_contained(ors, elidedExt);
}))
continue;

LLVM_DEBUG({
Expand All @@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
const ArrayRef<spirv::Capability> elidedCandidates = {}) {
for (const auto &ors : candidates) {
if (targetEnv.allows(ors))
if (targetEnv.allows(ors) ||
llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
return llvm::is_contained(ors, elidedCap);
}))
continue;

LLVM_DEBUG({
Expand All @@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
return success();
}

/// Returns true if the given `storageClass` needs explicit layout when used in
/// Shader environments.
/// Check capabilities and extensions requirements
/// Checks that `capCandidates`, `extCandidates`, and capability
/// (`capCandidates`) infered extension requirements are possible to be
/// satisfied with the given `targetEnv`.
/// It also provides a way to relax requirements for certain capabilities and
/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to
/// allow passes to relax certain requirements based on an option (e.g.,
/// relaxing bitwidth requirement, see `convertScalarType()`,
/// `ConvertVectorType()`).
template <typename LabelT>
static LogicalResult checkCapabilityAndExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
const ArrayRef<spirv::Capability> elidedCapCandidates = {},
const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
llvm::append_range(updatedExtCandidates, extCandidates);

if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
elidedCapCandidates)))
return failure();
// Add capablity infered extensions to the list of extension requirement list,
// only considers the capabilities that already available in the `targetEnv`.

// WARNING: Some capabilities are part of both the core SPIR-V
// specification and an extension (e.g., 'Groups' capability is part of both
// core specification and SPV_AMD_shader_ballot extension, hence we should
// relax the capability inferred extension for these cases).
static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
std::size(multiModalCaps));

for (auto cap : targetEnv.getAttr().getCapabilities()) {
if (llvm::any_of(multiModalCapsArrayRef,
[&cap](spirv::Capability mMCap) { return cap == mMCap; }))
continue;
std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
if (ext)
updatedExtCandidates.push_back(*ext);
}
if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
elidedExtCandidates)))
return failure();
return success();
}

/// Returns true if the given `storageClass` needs explicit layout when used
/// in Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
switch (storageClass) {
case spirv::StorageClass::PhysicalStorageBuffer:
Expand Down Expand Up @@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
type.getCapabilities(capabilities, storageClass);

// If all requirements are met, then we can accept this type as-is.
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
if (succeeded(checkCapabilityAndExtensionRequirements(
type, targetEnv, capabilities, extensions)))
return type;

// Otherwise we need to adjust the type, which really means adjusting the
Expand Down Expand Up @@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

// If the bit-width related capabilities and extensions are not met
// for lower bit-width (<32-bit), convert it to 32-bit
auto elementType =
convertScalarType(targetEnv, options, scalarType, storageClass);
if (!elementType)
return nullptr;
type = VectorType::get(type.getShape(), elementType);

SmallVector<spirv::Capability, 4> elidedCaps;
SmallVector<spirv::Extension, 4> elidedExts;

// Relax the bitwidth requirements for capabilities and extensions
if (options.emulateLT32BitScalarTypes) {
elidedCaps.push_back(spirv::Capability::Int8);
elidedCaps.push_back(spirv::Capability::Int16);
elidedCaps.push_back(spirv::Capability::Float16);
}
// For capabilities whose requirements were relaxed, relax requirements for
// the extensions that were infered by those capabilities (e.g., elidedCaps)
for (spirv::Capability cap : elidedCaps) {
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
if (ext)
llvm::append_range(elidedExts, *ext);
}
// If all requirements are met, then we can accept this type as-is.
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
if (succeeded(checkCapabilityAndExtensionRequirements(
type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
return type;

auto elementType =
convertScalarType(targetEnv, options, scalarType, storageClass);
if (elementType)
return VectorType::get(type.getShape(), elementType);
return nullptr;
}

Expand Down Expand Up @@ -656,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
SmallVector<ArrayRef<spirv::Capability>, 2> caps;
scalarType.getExtensions(exts);
scalarType.getCapabilities(caps);
if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
failed(checkExtensionRequirements(type, targetEnv, exts))) {

if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
exts))) {
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return castOp.getResult(0);
}
Expand Down Expand Up @@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
typeExtensions.clear();
cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
typeExtensions)))
return false;

typeCapabilities.clear();
cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
typeCapabilities)))
typeExtensions.clear();
cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
// Checking for capability and extension requirements along with capability
// infered extensions.
// If a capability is present, the extension that
// supports it should also be present, this reduces the burden of adding
// extension requirement that may or maynot be added in
// CompositeType::getExtensions().
if (failed(checkCapabilityAndExtensionRequirements(
op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
return false;
}

Expand Down
Loading