Skip to content

Commit

Permalink
Validate GroupNonUniform instructions (KhronosGroup#5296)
Browse files Browse the repository at this point in the history
Fixes KhronosGroup#5283

* Validate group non-uniform instructions
  • Loading branch information
alan-baker authored Jul 11, 2023
1 parent 4594ffc commit 0530a53
Show file tree
Hide file tree
Showing 5 changed files with 1,094 additions and 9 deletions.
347 changes: 346 additions & 1 deletion source/val/validate_non_uniform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,207 @@ namespace spvtools {
namespace val {
namespace {

spv_result_t ValidateGroupNonUniformElect(ValidationState_t& _,
const Instruction* inst) {
if (!_.IsBoolScalarType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a boolean scalar type";
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformAnyAll(ValidationState_t& _,
const Instruction* inst) {
if (!_.IsBoolScalarType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a boolean scalar type";
}

if (!_.IsBoolScalarType(_.GetOperandTypeId(inst, 3))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Predicate must be a boolean scalar type";
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformAllEqual(ValidationState_t& _,
const Instruction* inst) {
if (!_.IsBoolScalarType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a boolean scalar type";
}

const auto value_type = _.GetOperandTypeId(inst, 3);
if (!_.IsFloatScalarOrVectorType(value_type) &&
!_.IsIntScalarOrVectorType(value_type) &&
!_.IsBoolScalarOrVectorType(value_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Value must be a scalar or vector of integer, floating-point, or "
"boolean type";
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformBroadcastShuffle(ValidationState_t& _,
const Instruction* inst) {
const auto type_id = inst->type_id();
if (!_.IsFloatScalarOrVectorType(type_id) &&
!_.IsIntScalarOrVectorType(type_id) &&
!_.IsBoolScalarOrVectorType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a scalar or vector of integer, floating-point, "
"or boolean type";
}

const auto value_type_id = _.GetOperandTypeId(inst, 3);
if (value_type_id != type_id) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The type of Value must match the Result type";
}

const auto GetOperandName = [](const spv::Op opcode) {
std::string operand;
switch (opcode) {
case spv::Op::OpGroupNonUniformBroadcast:
case spv::Op::OpGroupNonUniformShuffle:
operand = "Id";
break;
case spv::Op::OpGroupNonUniformShuffleXor:
operand = "Mask";
break;
case spv::Op::OpGroupNonUniformQuadBroadcast:
operand = "Index";
break;
case spv::Op::OpGroupNonUniformQuadSwap:
operand = "Direction";
break;
case spv::Op::OpGroupNonUniformShuffleUp:
case spv::Op::OpGroupNonUniformShuffleDown:
default:
operand = "Delta";
break;
}
return operand;
};

const auto id_type_id = _.GetOperandTypeId(inst, 4);
if (!_.IsUnsignedIntScalarType(id_type_id)) {
std::string operand = GetOperandName(inst->opcode());
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< operand << " must be an unsigned integer scalar";
}

const bool should_be_constant =
inst->opcode() == spv::Op::OpGroupNonUniformQuadSwap ||
((inst->opcode() == spv::Op::OpGroupNonUniformBroadcast ||
inst->opcode() == spv::Op::OpGroupNonUniformQuadBroadcast) &&
_.version() < SPV_SPIRV_VERSION_WORD(1, 5));
if (should_be_constant) {
const auto id_id = inst->GetOperandAs<uint32_t>(4);
const auto id_op = _.GetIdOpcode(id_id);
if (!spvOpcodeIsConstant(id_op)) {
std::string operand = GetOperandName(inst->opcode());
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Before SPIR-V 1.5, " << operand
<< " must be a constant instruction";
}
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformBroadcastFirst(ValidationState_t& _,
const Instruction* inst) {
const auto type_id = inst->type_id();
if (!_.IsFloatScalarOrVectorType(type_id) &&
!_.IsIntScalarOrVectorType(type_id) &&
!_.IsBoolScalarOrVectorType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a scalar or vector of integer, floating-point, "
"or boolean type";
}

const auto value_type_id = _.GetOperandTypeId(inst, 3);
if (value_type_id != type_id) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The type of Value must match the Result type";
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformBallot(ValidationState_t& _,
const Instruction* inst) {
if (!_.IsUnsignedIntVectorType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a 4-component unsigned integer vector";
}

if (_.GetDimension(inst->type_id()) != 4) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a 4-component unsigned integer vector";
}

const auto pred_type_id = _.GetOperandTypeId(inst, 3);
if (!_.IsBoolScalarType(pred_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Predicate must be a boolean scalar";
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformInverseBallot(ValidationState_t& _,
const Instruction* inst) {
if (!_.IsBoolScalarType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a boolean scalar";
}

const auto value_type_id = _.GetOperandTypeId(inst, 3);
if (!_.IsUnsignedIntVectorType(value_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Value must be a 4-component unsigned integer vector";
}

if (_.GetDimension(value_type_id) != 4) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Value must be a 4-component unsigned integer vector";
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformBallotBitExtract(ValidationState_t& _,
const Instruction* inst) {
if (!_.IsBoolScalarType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a boolean scalar";
}

const auto value_type_id = _.GetOperandTypeId(inst, 3);
if (!_.IsUnsignedIntVectorType(value_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Value must be a 4-component unsigned integer vector";
}

if (_.GetDimension(value_type_id) != 4) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Value must be a 4-component unsigned integer vector";
}

const auto id_type_id = _.GetOperandTypeId(inst, 4);
if (!_.IsUnsignedIntScalarType(id_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Id must be an unsigned integer scalar";
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformBallotBitCount(ValidationState_t& _,
const Instruction* inst) {
// Scope is already checked by ValidateExecutionScope() above.
Expand Down Expand Up @@ -60,6 +261,107 @@ spv_result_t ValidateGroupNonUniformBallotBitCount(ValidationState_t& _,
return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformBallotFind(ValidationState_t& _,
const Instruction* inst) {
if (!_.IsUnsignedIntScalarType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be an unsigned integer scalar";
}

const auto value_type_id = _.GetOperandTypeId(inst, 3);
if (!_.IsUnsignedIntVectorType(value_type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Value must be a 4-component unsigned integer vector";
}

if (_.GetDimension(value_type_id) != 4) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Value must be a 4-component unsigned integer vector";
}

return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformArithmetic(ValidationState_t& _,
const Instruction* inst) {
const bool is_unsigned = inst->opcode() == spv::Op::OpGroupNonUniformUMin ||
inst->opcode() == spv::Op::OpGroupNonUniformUMax;
const bool is_float = inst->opcode() == spv::Op::OpGroupNonUniformFAdd ||
inst->opcode() == spv::Op::OpGroupNonUniformFMul ||
inst->opcode() == spv::Op::OpGroupNonUniformFMin ||
inst->opcode() == spv::Op::OpGroupNonUniformFMax;
const bool is_bool = inst->opcode() == spv::Op::OpGroupNonUniformLogicalAnd ||
inst->opcode() == spv::Op::OpGroupNonUniformLogicalOr ||
inst->opcode() == spv::Op::OpGroupNonUniformLogicalXor;
if (is_float) {
if (!_.IsFloatScalarOrVectorType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a floating-point scalar or vector";
}
} else if (is_bool) {
if (!_.IsBoolScalarOrVectorType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be a boolean scalar or vector";
}
} else if (is_unsigned) {
if (!_.IsUnsignedIntScalarOrVectorType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be an unsigned integer scalar or vector";
}
} else if (!_.IsIntScalarOrVectorType(inst->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Result must be an integer scalar or vector";
}

const auto value_type_id = _.GetOperandTypeId(inst, 4);
if (value_type_id != inst->type_id()) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "The type of Value must match the Result type";
}

const auto group_op = inst->GetOperandAs<spv::GroupOperation>(3);
bool is_clustered_reduce = group_op == spv::GroupOperation::ClusteredReduce;
bool is_partitioned_nv =
group_op == spv::GroupOperation::PartitionedReduceNV ||
group_op == spv::GroupOperation::PartitionedInclusiveScanNV ||
group_op == spv::GroupOperation::PartitionedExclusiveScanNV;
if (inst->operands().size() <= 5) {
if (is_clustered_reduce) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "ClusterSize must be present when Operation is ClusteredReduce";
} else if (is_partitioned_nv) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ballot must be present when Operation is PartitionedReduceNV, "
"PartitionedInclusiveScanNV, or PartitionedExclusiveScanNV";
}
} else {
const auto operand_id = inst->GetOperandAs<uint32_t>(5);
const auto* operand = _.FindDef(operand_id);
if (is_partitioned_nv) {
if (!operand || !_.IsIntScalarOrVectorType(operand->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ballot must be a 4-component integer vector";
}

if (_.GetDimension(operand->type_id()) != 4) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ballot must be a 4-component integer vector";
}
} else {
if (!operand || !_.IsUnsignedIntScalarType(operand->type_id())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "ClusterSize must be an unsigned integer scalar";
}

if (!spvOpcodeIsConstant(operand->opcode())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "ClusterSize must be a constant instruction";
}
}
}
return SPV_SUCCESS;
}

spv_result_t ValidateGroupNonUniformRotateKHR(ValidationState_t& _,
const Instruction* inst) {
// Scope is already checked by ValidateExecutionScope() above.
Expand Down Expand Up @@ -120,15 +422,58 @@ spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst) {
const spv::Op opcode = inst->opcode();

if (spvOpcodeIsNonUniformGroupOperation(opcode)) {
const uint32_t execution_scope = inst->word(3);
const uint32_t execution_scope = inst->GetOperandAs<uint32_t>(2);
if (auto error = ValidateExecutionScope(_, inst, execution_scope)) {
return error;
}
}

switch (opcode) {
case spv::Op::OpGroupNonUniformElect:
return ValidateGroupNonUniformElect(_, inst);
case spv::Op::OpGroupNonUniformAny:
case spv::Op::OpGroupNonUniformAll:
return ValidateGroupNonUniformAnyAll(_, inst);
case spv::Op::OpGroupNonUniformAllEqual:
return ValidateGroupNonUniformAllEqual(_, inst);
case spv::Op::OpGroupNonUniformBroadcast:
case spv::Op::OpGroupNonUniformShuffle:
case spv::Op::OpGroupNonUniformShuffleXor:
case spv::Op::OpGroupNonUniformShuffleUp:
case spv::Op::OpGroupNonUniformShuffleDown:
case spv::Op::OpGroupNonUniformQuadBroadcast:
case spv::Op::OpGroupNonUniformQuadSwap:
return ValidateGroupNonUniformBroadcastShuffle(_, inst);
case spv::Op::OpGroupNonUniformBroadcastFirst:
return ValidateGroupNonUniformBroadcastFirst(_, inst);
case spv::Op::OpGroupNonUniformBallot:
return ValidateGroupNonUniformBallot(_, inst);
case spv::Op::OpGroupNonUniformInverseBallot:
return ValidateGroupNonUniformInverseBallot(_, inst);
case spv::Op::OpGroupNonUniformBallotBitExtract:
return ValidateGroupNonUniformBallotBitExtract(_, inst);
case spv::Op::OpGroupNonUniformBallotBitCount:
return ValidateGroupNonUniformBallotBitCount(_, inst);
case spv::Op::OpGroupNonUniformBallotFindLSB:
case spv::Op::OpGroupNonUniformBallotFindMSB:
return ValidateGroupNonUniformBallotFind(_, inst);
case spv::Op::OpGroupNonUniformIAdd:
case spv::Op::OpGroupNonUniformFAdd:
case spv::Op::OpGroupNonUniformIMul:
case spv::Op::OpGroupNonUniformFMul:
case spv::Op::OpGroupNonUniformSMin:
case spv::Op::OpGroupNonUniformUMin:
case spv::Op::OpGroupNonUniformFMin:
case spv::Op::OpGroupNonUniformSMax:
case spv::Op::OpGroupNonUniformUMax:
case spv::Op::OpGroupNonUniformFMax:
case spv::Op::OpGroupNonUniformBitwiseAnd:
case spv::Op::OpGroupNonUniformBitwiseOr:
case spv::Op::OpGroupNonUniformBitwiseXor:
case spv::Op::OpGroupNonUniformLogicalAnd:
case spv::Op::OpGroupNonUniformLogicalOr:
case spv::Op::OpGroupNonUniformLogicalXor:
return ValidateGroupNonUniformArithmetic(_, inst);
case spv::Op::OpGroupNonUniformRotateKHR:
return ValidateGroupNonUniformRotateKHR(_, inst);
default:
Expand Down
17 changes: 17 additions & 0 deletions source/val/validation_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,23 @@ bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
return false;
}

bool ValidationState_t::IsUnsignedIntScalarOrVectorType(uint32_t id) const {
const Instruction* inst = FindDef(id);
if (!inst) {
return false;
}

if (inst->opcode() == spv::Op::OpTypeInt) {
return inst->GetOperandAs<uint32_t>(2) == 0;
}

if (inst->opcode() == spv::Op::OpTypeVector) {
return IsUnsignedIntScalarType(GetComponentType(id));
}

return false;
}

bool ValidationState_t::IsSignedIntScalarType(uint32_t id) const {
const Instruction* inst = FindDef(id);
return inst && inst->opcode() == spv::Op::OpTypeInt && inst->word(3) == 1;
Expand Down
Loading

0 comments on commit 0530a53

Please sign in to comment.