DirectML EP remove stale code for int64 via int32 double strides (#9959)

fdwr authored Jan 10, 2022
1 parent 1f5b073 commit 0f5e82c
Showing 21 changed files with 111 additions and 391 deletions.
@@ -108,24 +108,6 @@ namespace Windows::AI::MachineLearning::Adapter
std::vector<uint32_t> requiredConstantCpuInputs;
std::optional<GraphNodeFactoryRegistration> graphNodeFactoryRegistration;
KernelSupportQuery supportQuery;

// Many ONNX operators use 64-bit tensors, but most DML operators only support
// 32-bit indices. This flag indicates to the graph whether it's okay to compute
// the result using 32-bit tensors (ignoring the upper bits) via doubled strides.
bool supportedWith64BitTensorsVia32BitStrides = false;

// When true, the input to the current operator may come from any execution
// provider. Otherwise it must have come from another DML node to assume it's safe
// to use 64-bit to 32-bit striding.
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;

// Operator supports true 64-bit tensors directly, no strides needed.
// So fallback to strided 32-bit only occurs when the device lacks 64-bit support.
bool prefer64BitTensorsDirectly = false;

// The operator supports emulation for uint64/int64 even if the hardware doesn't
// support native uint64/int64 data types.
bool support64BitTensorsViaEmulation = false;
};

using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;
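For context, the four flags deleted above gated a workaround in which an int64/uint64 tensor is described to DML as a 32-bit tensor with doubled strides, so every access lands on the low word of each 64-bit element. This is only valid on little-endian layouts and only when the values actually fit in 32 bits. A minimal self-contained sketch of the stride doubling follows; the struct and function names here are hypothetical, not part of the codebase.

#include <cstdint>
#include <vector>

// Illustration of the "64-bit via 32-bit strides" trick: keep the logical
// sizes, but express strides in uint32 elements, doubled, so each access
// reads the low word of an int64 element.
struct Strided32BitView
{
    std::vector<uint32_t> sizes;   // same logical shape as the int64 tensor
    std::vector<uint32_t> strides; // strides counted in uint32 elements
};

Strided32BitView MakeDoubledStrideView(const std::vector<uint32_t>& sizes)
{
    Strided32BitView view;
    view.sizes = sizes;
    view.strides.resize(sizes.size());

    // Packed strides measured in int64 elements, then doubled to address
    // uint32 elements (2 uint32 words per int64 element, low word first on
    // little-endian hardware).
    uint32_t stride = 1;
    for (size_t i = sizes.size(); i-- > 0; )
    {
        view.strides[i] = stride * 2;
        stride *= sizes[i];
    }
    return view;
}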
@@ -342,10 +342,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
bool support64BitTensorsViaEmulation,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept
{
@@ -471,10 +467,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
{
auto regInfo = std::make_shared<InternalRegistrationInfo>();
regInfo->requiredConstantCpuInputs = constantCpuInputCapture;
regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides;
regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp;
regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly;
regInfo->support64BitTensorsViaEmulation = support64BitTensorsViaEmulation;

// Only internal operators support usage in DML graphs
if (supportsGraph)
@@ -546,11 +538,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
if (canAliasFirstInput ||
supportsGraph ||
requiredInputCountForGraph ||
- requiredConstantCpuInputs ||
- supportedWith64BitTensorsVia32BitStrides ||
- supportedWith64BitTensorsVia32BitStridesFromAnyEp ||
- prefer64BitTensorsDirectly ||
- support64BitTensorsViaEmulation)
+ requiredConstantCpuInputs)
{
ORT_THROW_HR(E_INVALIDARG);
}
@@ -41,10 +41,6 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph = nullptr,
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,
bool support64BitTensorsViaEmulation = false,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0) const noexcept override;

@@ -139,78 +139,6 @@ namespace Dml
}
};

bool NodeArgSupportedInGraph(
const onnxruntime::NodeArg* arg,
bool supports64BitTensorsViaEmulation,
uint32_t supportedDeviceDataTypeMask
)
{
if (arg->Exists())
{
const onnx::TypeProto* typeProto = arg->TypeAsProto();
if (typeProto->value_case() == onnx::TypeProto::kTensorType)
{
const onnx::TypeProto_Tensor tensorType = typeProto->tensor_type();
if (tensorType.has_elem_type())
{
// TODO: Remove this by handling zeroing on the output of fused graph nodes and handling of non-float
// types in DML's identity operator, which is used for strided copies.

MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type()));

// Do not include operators in the graph if tensor types are unsupported,
// except cases that are always supported via emulation.
if ((mlDataType == MLOperatorTensorDataType::UInt64 ||
mlDataType == MLOperatorTensorDataType::Int64) &&
!supports64BitTensorsViaEmulation)
{
constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64);
if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit)
{
return false;
}
}

}
}
}

return true;
}

bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const InternalRegistrationInfo& registration, uint32_t supportedDeviceDataTypeMask)
{
for (size_t i = 0; i < node.InputDefs().size(); ++i)
{
bool isConstantCpuInput = std::find(registration.requiredConstantCpuInputs.begin(), registration.requiredConstantCpuInputs.end(), i) !=
registration.requiredConstantCpuInputs.end();

if (!isConstantCpuInput &&
!NodeArgSupportedInGraph(
node.InputDefs()[i],
registration.support64BitTensorsViaEmulation,
supportedDeviceDataTypeMask
))
{
return false;
}
}

for (auto arg : node.OutputDefs())
{
if (!NodeArgSupportedInGraph(
arg,
registration.support64BitTensorsViaEmulation,
supportedDeviceDataTypeMask
))
{
return false;
}
}

return true;
}
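As an aside, the capability test embedded in NodeArgSupportedInGraph above distills to the bitmask check below (a sketch; the helper name is hypothetical, and the DML_TENSOR_DATA_TYPE enum comes from DirectML.h). A device is treated as natively 64-bit capable only when both the UINT64 and INT64 bits are set in its data-type mask.

#include <DirectML.h>
#include <cstdint>

constexpr uint32_t kDeviceDataTypeMask64Bit =
    (1u << DML_TENSOR_DATA_TYPE_UINT64) | (1u << DML_TENSOR_DATA_TYPE_INT64);

bool DeviceSupportsNative64BitInts(uint32_t supportedDeviceDataTypeMask)
{
    // Both bits must be set; supporting only one of int64/uint64 counts as
    // no native 64-bit support, mirroring the removed check above.
    return (supportedDeviceDataTypeMask & kDeviceDataTypeMask64Bit)
        == kDeviceDataTypeMask64Bit;
}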

bool TryGetTensorDataType(
const onnxruntime::NodeArg& nodeArg,
_Out_ MLOperatorTensorDataType* onnxElementType
@@ -242,26 +170,10 @@
{
ORT_THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap);

bool prefer64BitTensorsDirectly = false;
bool support64BitTensorsViaEmulation = false;
bool supportedWith64BitTensorsVia32BitStrides = false;
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;
std::vector<onnxruntime::NodeArg const*> constantCpuInputs;

if (regInfo != nullptr)
{
// Read the operator flags for handling 64-bit tensors and whether it's allowed to fall back
// to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false
// in this particular call, then the operator-specific flags do not matter as the caller has
// disabled 64-bit support.
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
support64BitTensorsViaEmulation = regInfo->support64BitTensorsViaEmulation;
if (allow64BitInputThroughStrides)
{
supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp;
supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp;
}

// Collect the list of CPU-bound input tensors, needed when checking 64-bit fallback
// or for other data types like int-8 which may be supported for CPU inputs but not
// GPU inputs.
@@ -317,55 +229,7 @@
return;
}

// If this operator implements 64-bit support in terms of strided 32-bit tensors,
// then the data type needs to be remapped, regardless of whether input or output.
//
// Some operators can fairly safely implement 64-bit tensors in terms of
// strided 32-bit tensors regardless of input tensor's execution provider
// because the indices measure along a single axis and should fall within
// the range of an int32/uint32.
//
// Currently all DML kernels outputting int64 and uint64 are expected to
// not *introduce* values out of range, which allows the temporary trick
// using strides to emulate 64 bit tensors to work. If the source is a CPU
// operator, graph input or initializer, it's not safe to assume the input
// can be represented with 32 bits.
//
bool isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask;
bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64);
if (is64BitIntType)
{
if (support64BitTensorsViaEmulation)
{
// Consider it supported regardless of hardware support.
isDataTypeSupported = true;
}
else if (prefer64BitTensorsDirectly && isDataTypeSupported)
{
// Operator supports native int64/uint64 tensors.
}
else if (supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp)
{
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask;

if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
{
// Look up the input partition. If it's a graph input or initializer it will be missing
// from the partition map.
const std::string& argName = nodeArg.Name();

// If input tensor's data comes from the output of a different execution provider,
// consider it unsafe to apply fallback to.
auto partitionIter = nodeNameToPartitionMap->find(argName);
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
{
nodeContainsSupportedDataTypes = false;
return;
}
}
}
}
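(The Remap64bitDmlDataTypeTo32bit call above presumably collapses the element type as sketched below; this is inferred from the call sites, not the actual implementation.)

DML_TENSOR_DATA_TYPE Remap64bitDmlDataTypeTo32bit(DML_TENSOR_DATA_TYPE dataType)
{
    // 64-bit integer types map to their 32-bit counterparts so doubled
    // strides can address the low words; all other types pass through.
    switch (dataType)
    {
    case DML_TENSOR_DATA_TYPE_INT64:  return DML_TENSOR_DATA_TYPE_INT32;
    case DML_TENSOR_DATA_TYPE_UINT64: return DML_TENSOR_DATA_TYPE_UINT32;
    default:                          return dataType;
    }
}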

// Reject node if the data type is unsupported by the device.
if (!isDataTypeSupported)
@@ -465,8 +329,7 @@
{
auto internalRegInfo = regInfoIter->second;

- if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration &&
-     NodeTensorTypesSupportedInGraph(node, *internalRegInfo, supportedDeviceDataTypeMask))
+ if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration)
{
bool requiredCpuInputsConstant = true;
for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs)
@@ -389,34 +389,6 @@ namespace Dml
));
}

void DmlOperator::Remap64bitDmlDataTypesTo32bit()
{
for (auto& tensor : m_inputTensorDescs)
{
tensor.Remap64bitDmlDataTypeTo32bit();
}

for (auto& tensor : m_outputTensorDescs)
{
tensor.Remap64bitDmlDataTypeTo32bit();
}
}

void DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded()
{
// Conditionally remap 64-bit data types to strided 32-bit if DML does not
// support 64-bit data types directly on the device.

uint32_t deviceTypeMask = Dml::GetSupportedDeviceDataTypeMask(m_dmlDevice.Get());
uint32_t deviceTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_INT64) | (1 << DML_TENSOR_DATA_TYPE_UINT64);

// If the device doesn't support 64-bit tensors, fall back to 32-bit with strides.
if (!(deviceTypeMask & deviceTypeMask64bit))
{
Remap64bitDmlDataTypesTo32bit();
}
}

TensorDesc DmlOperator::CreateTensorDescFromInput(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t index,
@@ -85,11 +85,6 @@ namespace Dml

void ExecuteZeroInt64Tensor(IDMLCompiledOperator* compiledOperator, IMLOperatorTensor* tensor);

// Remap 64-bit data types to 32-bit via doubled strides.
// These should be called before GetDmlInputDescs or GetDmlOutputDescs.
void Remap64bitDmlDataTypesTo32bit();
void Remap64bitDmlDataTypesTo32bitIfNeeded();

TensorDesc CreateTensorDescFromInput(
const MLOperatorKernelCreationContext& kernelInfo,
uint32_t index,
@@ -24,8 +24,6 @@ class DmlOperatorGather : public DmlOperator, public GatherHelper
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));

DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
@@ -62,8 +60,6 @@ class DmlOperatorGatherElements : public DmlOperator
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));

DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
@@ -101,8 +97,6 @@ class DmlOperatorGatherNd : public DmlOperator, public GatherNdHelper
size_t dimensionCountMax = std::max({dataDimensions.size(), indicesDimensions.size(), outputDimensions.size()});
DmlOperator::Initialize(kernelCreationContext, gsl::narrow_cast<uint32_t>(dimensionCountMax));

DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 2);
@@ -22,7 +22,6 @@ class DmlOperatorMaxUnpool : public DmlOperator, public UnpoolingHelper
std::vector<std::optional<uint32_t>> inputIndices = { 0, 1 }; // The 3rd tensor ('output_shape') is not bound, just 'X' and 'I' indices.
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices);
DmlOperator::Remap64bitDmlDataTypesTo32bit();
m_inputTensorDescs[1].ForceUnsignedDataType(); // MaxUnpool accepts uint32_t.

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
@@ -53,8 +53,6 @@ class DmlOperatorOneHot : public DmlOperator, OneHotHelper
0
);

DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();

// Adjust the axis so it's in DML's terms rather than the original ONNX indexing.
uint32_t dmlAxis = GetDmlAdjustedAxis(
m_absoluteAxis,
@@ -111,8 +111,7 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase

if (hasOutputIndices)
{
- DmlOperator::Remap64bitDmlDataTypesTo32bit();
- m_outputTensorDescs[1].ForceUnsignedDataType(); // MaxPool accepts uint32_t.
+ m_outputTensorDescs[1].ForceUnsignedDataType(); // MaxPool accepts uint32_t/uint64_t.
desc.OutputIndicesTensor = &outputDescs[1];
}

@@ -67,14 +67,6 @@ class DmlOperatorReduce : public DmlOperator, public ReduceHelperBase
argmaxDesc.Axes = dmlAxes.data();
argmaxDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());

// If the 64-bit tensors were remapped to 32-bit, then we need to clear the upper 32-bits
// of each element. If the device directly supports 64-bit elements, then no need.
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
if (m_outputTensorDescs[0].WasRemapped64bitTo32bit())
{
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
}

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ARGMAX, &argmaxDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
@@ -87,14 +79,6 @@ class DmlOperatorReduce : public DmlOperator, public ReduceHelperBase
argminDesc.Axes = dmlAxes.data();
argminDesc.AxisCount = gsl::narrow_cast<uint32_t>(dmlAxes.size());

// If the 64-bit tensors were remapped to 32-bit, then we need to clear the upper 32-bits
// of each element. If the device directly supports 64-bit elements, then no need.
DmlOperator::Remap64bitDmlDataTypesTo32bitIfNeeded();
if (m_outputTensorDescs[0].WasRemapped64bitTo32bit())
{
m_zeroOperator = InitializeZeroInt64Tensor(m_outputTensorDescs[0].GetBufferSizeInBytes());
}

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ARGMIN, &argminDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
@@ -117,20 +101,12 @@ class DmlOperatorReduce : public DmlOperator, public ReduceHelperBase
std::vector<IMLOperatorTensor*> inputTensors = GetInputTensorsForExecute(kernelContext);
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);

if (m_zeroOperator)
{
ExecuteZeroInt64Tensor(m_zeroOperator.Get(), outputTensors[0]);
}

ORT_THROW_IF_FAILED(m_executionProvider->ExecuteOperator(
m_compiledOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(inputTensors),
gsl::make_span(outputTensors)));
}

private:
ComPtr<IDMLCompiledOperator> m_zeroOperator;
};

// A specific type of operation for registration.
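To make the removed zeroing pass concrete: an emulated 64-bit output written through a doubled-stride 32-bit view only touches the low word of each element, so stale high words must be cleared first. Zero-extension is correct for argmin/argmax because indices are never negative. Below is a CPU-side sketch of the same idea; the function name is hypothetical.

#include <cstdint>
#include <cstring>
#include <vector>

void WriteIndicesAsEmulatedInt64(std::vector<int64_t>& output,
                                 const std::vector<int32_t>& indices)
{
    // Counterpart of ExecuteZeroInt64Tensor: clear stale high words first.
    std::memset(output.data(), 0, output.size() * sizeof(int64_t));

    // Counterpart of the strided 32-bit write: element i's low word sits at
    // byte offset i * sizeof(int64_t) (little-endian assumed).
    for (size_t i = 0; i < indices.size() && i < output.size(); ++i)
    {
        std::memcpy(reinterpret_cast<char*>(output.data()) + i * sizeof(int64_t),
                    &indices[i], sizeof(int32_t));
    }
}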