Skip to content

Commit

Permalink
Merge pull request microsoft#4925 from microsoft/user/dwayner/Iron
Browse files Browse the repository at this point in the history
ORT DirectML EP for Iron release, ONNX 1.5
  • Loading branch information
fdwr authored Aug 28, 2020
2 parents 1281ff6 + 79429c9 commit 040c5fa
Show file tree
Hide file tree
Showing 60 changed files with 2,954 additions and 875 deletions.
2 changes: 1 addition & 1 deletion cmake/external/dml.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.2.1.0)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.3.0.0)

# Restore nuget packages, which will pull down the DirectML redist package
add_custom_command(
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/dml/.clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Readability matters. Prevent syntax noise in pull requests for people who
# accidentally leave enabled the auto-formatting options in Visual Studio.
DisableFormat: true
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,16 @@ namespace Windows::AI::MachineLearning::Adapter
const void* executionHandle,
DmlGraphNodeCreateInfo* graphNodeCreateInfo
)>;

// Registration data describing how an operator participates in DML graph fusion.
struct GraphNodeFactoryRegistration
{
// Callback invoked to build the fused DML graph node for this operator
// (see the GraphNodeFactory alias above: it receives the execution handle
// and fills a DmlGraphNodeCreateInfo).
GraphNodeFactory factory;

// When set, graph usage presumably requires the node to have exactly this
// many inputs — TODO confirm against the partitioning logic that reads it.
std::optional<uint32_t> requiredInputCount;

// The operator inputs/outputs must be a floating point data type. When true,
// if the node's tensor data type is not floating point, the node is partitioned
// separately (unless the input/output is a CPU constant input, which is okay,
// as those can be read directly by the DML operator in the DML_OPERATOR_DESC).
bool requiresFloatFormatsExceptConstInputs = false;
};

Expand All @@ -109,6 +114,20 @@ namespace Windows::AI::MachineLearning::Adapter
std::vector<uint32_t> requiredConstantCpuInputs;
std::optional<GraphNodeFactoryRegistration> graphNodeFactoryRegistration;
KernelSupportQuery supportQuery;

// Many ONNX operators use 64-bit tensors, but most DML operators only support
// 32-bit indices. This flag indicates to the graph whether it's okay to compute
// the result using 32-bit tensors (ignoring the upper bits) via doubled strides.
bool supportedWith64BitTensorsVia32BitStrides = false;

// When true, the input to the current operator may come from any execution
// provider. Otherwise it must have come from another DML node to assume it's safe
// to use 64-bit to 32-bit striding.
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;

// Operator supports true 64-bit tensors directly, no strides needed.
// So fallback to strided 32-bit only occurs when the device lacks 64-bit support.
bool prefer64BitTensorsDirectly = false;
};

using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
bool requiresFloatFormatsForGraph,
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept try
{
Expand Down Expand Up @@ -456,6 +459,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
{
auto regInfo = std::make_shared<InternalRegistrationInfo>();
regInfo->requiredConstantCpuInputs = constantCpuInputCapture;
regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides;
regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp;
regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly;

// Only internal operators support usage in DML graphs
if (supportsGraph)
Expand Down Expand Up @@ -527,8 +533,14 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
else
{
// Currently unsupported for external operators
if (canAliasFirstInput || supportsGraph || requiredInputCountForGraph ||
requiresFloatFormatsForGraph || requiredConstantCpuInputs)
if (canAliasFirstInput ||
supportsGraph ||
requiredInputCountForGraph ||
requiresFloatFormatsForGraph ||
requiredConstantCpuInputs ||
supportedWith64BitTensorsVia32BitStrides ||
supportedWith64BitTensorsVia32BitStridesFromAnyEp ||
prefer64BitTensorsDirectly)
{
THROW_HR(E_INVALIDARG);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
bool supportsGraph,
const uint32_t* requiredInputCountForGraph = nullptr,
bool requiresFloatFormatsForGraph = false,
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0) const noexcept override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,42 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataTyp
case MLOperatorTensorDataType::UInt16: return DML_TENSOR_DATA_TYPE_UINT16;
case MLOperatorTensorDataType::Int16: return DML_TENSOR_DATA_TYPE_INT16;
case MLOperatorTensorDataType::Int32: return DML_TENSOR_DATA_TYPE_INT32;
case MLOperatorTensorDataType::Int64: return DML_TENSOR_DATA_TYPE_UINT32;
case MLOperatorTensorDataType::Int64: return DML_TENSOR_DATA_TYPE_INT64;
case MLOperatorTensorDataType::String: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Bool: return DML_TENSOR_DATA_TYPE_UINT8;
case MLOperatorTensorDataType::Float16: return DML_TENSOR_DATA_TYPE_FLOAT16;
case MLOperatorTensorDataType::Double: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::UInt32: return DML_TENSOR_DATA_TYPE_UINT32;
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT32; // Stride is used to access lower 32-bits.
case MLOperatorTensorDataType::UInt64: return DML_TENSOR_DATA_TYPE_UINT64;
case MLOperatorTensorDataType::Complex64: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Complex128: return DML_TENSOR_DATA_TYPE_UNKNOWN;
case MLOperatorTensorDataType::Undefined:
default: return DML_TENSOR_DATA_TYPE_UNKNOWN;;
default: return DML_TENSOR_DATA_TYPE_UNKNOWN;
};
}

// Maps 64-bit integer DML tensor data types to their 32-bit counterparts,
// returning all other types unchanged. Used when a device lacks native 64-bit
// tensor support and the upper 32 bits are instead skipped via strides.
DML_TENSOR_DATA_TYPE Remap64bitDmlDataTypeTo32bit(DML_TENSOR_DATA_TYPE dmlElementType) noexcept
{
    switch (dmlElementType)
    {
    // Note: no 'break' needed — 'return' already exits the switch, and a
    // 'break' after it is unreachable dead code.
    case DML_TENSOR_DATA_TYPE_UINT64: return DML_TENSOR_DATA_TYPE_UINT32;
    case DML_TENSOR_DATA_TYPE_INT64: return DML_TENSOR_DATA_TYPE_INT32;
    default: return dmlElementType;
    }
}

bool IsSigned(DML_TENSOR_DATA_TYPE dataType)
{
switch (dataType)
{
case DML_TENSOR_DATA_TYPE_FLOAT64: return true;
case DML_TENSOR_DATA_TYPE_FLOAT32: return true;
case DML_TENSOR_DATA_TYPE_FLOAT16: return true;
case DML_TENSOR_DATA_TYPE_UINT64: return false;
case DML_TENSOR_DATA_TYPE_UINT32: return false;
case DML_TENSOR_DATA_TYPE_UINT16: return false;
case DML_TENSOR_DATA_TYPE_UINT8: return false;
case DML_TENSOR_DATA_TYPE_INT64: return true;
case DML_TENSOR_DATA_TYPE_INT32: return true;
case DML_TENSOR_DATA_TYPE_INT16: return true;
case DML_TENSOR_DATA_TYPE_INT8: return true;
Expand Down Expand Up @@ -70,9 +83,14 @@ MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tenso
case DML_TENSOR_DATA_TYPE_INT32: return MLOperatorTensorDataType::Int32;
case DML_TENSOR_DATA_TYPE_FLOAT16: return MLOperatorTensorDataType::Float16;
case DML_TENSOR_DATA_TYPE_UINT32: return MLOperatorTensorDataType::UInt32;
case DML_TENSOR_DATA_TYPE_UINT64: return MLOperatorTensorDataType::UInt64;
case DML_TENSOR_DATA_TYPE_INT64: return MLOperatorTensorDataType::Int64;
case DML_TENSOR_DATA_TYPE_FLOAT64: return MLOperatorTensorDataType::Double;

default: ML_INVALID_ARGUMENT("Unknown DML_TENSOR_DATA_TYPE.");
};
}

size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType)
{
return ComputeElementCountFromDimensions(dimensions) * GetByteSizeFromMlDataType(tensorDataType);
Expand All @@ -90,4 +108,40 @@ size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor)
return ComputeByteSizeFromDimensions(gsl::make_span(dimensions.data(), dimensionCount), tensor.GetTensorDataType());
}

// Builds a bitmask describing which DML_TENSOR_DATA_TYPE values the given
// device supports; bit i of the result corresponds to enum value i.
// Throws on any CheckFeatureSupport failure.
uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice)
{
    uint32_t supportedTypeMask = 0u;

    // Query the device once per data-type enumerant and accumulate the bits.
    for (uint32_t typeIndex = 0; typeIndex <= DML_TENSOR_DATA_TYPE_INT8; ++typeIndex)
    {
        DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT query = { static_cast<DML_TENSOR_DATA_TYPE>(typeIndex) };
        DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT result = {};

        THROW_IF_FAILED(dmlDevice->CheckFeatureSupport(
            DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT,
            sizeof(query),
            &query,
            sizeof(result),
            &result
        ));

        if (result.IsSupported)
        {
            supportedTypeMask |= (1u << typeIndex);
        }
    }

    return supportedTypeMask;
}

// Fills 'strides' with the packed (contiguous, no-padding) strides for a
// tensor of the given sizes, where the last dimension varies fastest:
// each stride is the product of all sizes of the faster-varying dimensions.
// Both spans must have the same length.
void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides)
{
    assert(sizes.size() == strides.size());

    // Walk from the innermost dimension outward, accumulating the element
    // count of all dimensions processed so far.
    uint32_t elementsSoFar = 1;
    for (size_t dim = strides.size(); dim > 0; --dim)
    {
        const size_t index = dim - 1;
        strides[index] = elementsSoFar;
        elementsSoFar *= sizes[index];
    }
}

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ namespace Dml
{
using namespace OperatorHelper;

static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX;
static const int MaximumDimensionCount = DML_TENSOR_DIMENSION_COUNT_MAX1;

DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataType(MLOperatorTensorDataType tensorDataType);
DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataType tensorDataType) noexcept;
DML_TENSOR_DATA_TYPE Remap64bitDmlDataTypeTo32bit(DML_TENSOR_DATA_TYPE dmlElementType) noexcept;
MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tensorDataType);
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType);
size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor);
uint32_t GetSupportedDeviceDataTypeMask(IDMLDevice* dmlDevice);
void GetDescendingPackedStrides(gsl::span<const uint32_t> sizes, /*out*/ gsl::span<uint32_t> strides);

bool IsSigned(DML_TENSOR_DATA_TYPE dataType);

Expand All @@ -40,6 +43,12 @@ namespace Dml
UINT elementSizeInBytes = 0;
switch (dataType)
{
case DML_TENSOR_DATA_TYPE_FLOAT64:
case DML_TENSOR_DATA_TYPE_UINT64:
case DML_TENSOR_DATA_TYPE_INT64:
elementSizeInBytes = 8;
break;

case DML_TENSOR_DATA_TYPE_FLOAT32:
case DML_TENSOR_DATA_TYPE_UINT32:
case DML_TENSOR_DATA_TYPE_INT32:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,9 @@ namespace Dml
{
assert(!m_closed);

const size_t sourceSizeInBytes = ComputeByteSizeFromTensor(*src);
const size_t dataSizeInBytes = ComputeByteSizeFromTensor(*dst);
THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != ComputeByteSizeFromTensor(*src)); // Tensors must be the same size
THROW_HR_IF(E_INVALIDARG, dataSizeInBytes != sourceSizeInBytes); // Tensors must be the same size

if (dataSizeInBytes == 0)
{
Expand Down Expand Up @@ -461,7 +462,7 @@ namespace Dml
}
CATCH_RETURN();

uint32_t ExecutionProviderImpl::GetSuppportedDeviceDataTypeMask() const
uint32_t ExecutionProviderImpl::GetSupportedDeviceDataTypeMask() const
{
// The DML provider registers all supported kernels up-front regardless of actual device capability,
// but this is problematic later when executing the graph because DirectML will fail to create
Expand All @@ -470,26 +471,7 @@ namespace Dml
// handle them, similar to the fallback in CUDAExecutionProvider::GetCapability for certain RNN/GRU/Conv
// attributes.

uint32_t deviceTypeMask = 0u;

// Form the bitmask of all supported data types.
for (uint32_t i = 0; i <= DML_TENSOR_DATA_TYPE_INT8; ++i)
{
DML_FEATURE_QUERY_TENSOR_DATA_TYPE_SUPPORT dataTypeQuery = { static_cast<DML_TENSOR_DATA_TYPE>(i) };
DML_FEATURE_DATA_TENSOR_DATA_TYPE_SUPPORT dataTypeSupport = {};

THROW_IF_FAILED(m_dmlDevice->CheckFeatureSupport(
DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT,
sizeof(dataTypeQuery),
&dataTypeQuery,
sizeof(dataTypeSupport),
&dataTypeSupport
));

deviceTypeMask |= (dataTypeSupport.IsSupported << i);
}

return deviceTypeMask;
return Dml::GetSupportedDeviceDataTypeMask(m_dmlDevice.Get());
}

std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
Expand All @@ -498,7 +480,7 @@ namespace Dml
const std::vector<const onnxruntime::KernelRegistry*>& registries) const
{
std::string partitionKernelPrefix = std::to_string(m_partitionKernelPrefixVal++) + "_";
uint32_t deviceDataTypeMask = GetSuppportedDeviceDataTypeMask();
uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask();

return PartitionGraph(
graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ using Base = Microsoft::WRL::RuntimeClass<
TInterfaces...>;
}

using namespace Microsoft::WRL;

namespace Dml
{
using Microsoft::WRL::ComPtr;
class PooledUploadHeap;
class ReadbackHeap;
class ExecutionContext;
Expand Down Expand Up @@ -87,7 +86,7 @@ namespace Dml
const std::vector<const onnxruntime::KernelRegistry*>& registries
) const;

uint32_t GetSuppportedDeviceDataTypeMask() const;
uint32_t GetSupportedDeviceDataTypeMask() const;

onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const;
onnxruntime::common::Status CopyTensors(const std::vector<onnxruntime::IDataTransfer::SrcDstPair>& src_dst_pairs) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ union ActivationOperatorDescUnion
{
DML_ACTIVATION_IDENTITY_OPERATOR_DESC identity;
DML_ACTIVATION_ELU_OPERATOR_DESC elu;
DML_ACTIVATION_CELU_OPERATOR_DESC celu;
DML_ACTIVATION_HARDMAX_OPERATOR_DESC hardmax;
DML_ACTIVATION_HARD_SIGMOID_OPERATOR_DESC hardSigmoid;
DML_ACTIVATION_LEAKY_RELU_OPERATOR_DESC leakyRelu;
Expand Down Expand Up @@ -36,6 +37,7 @@ struct ActivationOperatorDesc
switch (activationType)
{
case DML_OPERATOR_ACTIVATION_ELU: return { activationType, &params.elu };
case DML_OPERATOR_ACTIVATION_CELU: return { activationType, &params.celu };
case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, &params.hardmax };
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, &params.sigmoid };
case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, &params.identity };
Expand Down
Loading

0 comments on commit 040c5fa

Please sign in to comment.