Skip to content

Commit

Permalink
User/sheilk/sequence fix (#15239)
Browse files Browse the repository at this point in the history
Ensure that Loop operators run on CPU.
Fix memcpy for Sequence Tensors, so that empty sequences (like when
SequenceEmpty runs on DirectML) can be copied back to CPU.
  • Loading branch information
smk2007 authored Mar 31, 2023
1 parent c06ab5e commit 7ccdf9a
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ namespace Dml
}

dataSizesInBytes.push_back(static_cast<uint32_t>(ComputeByteSizeFromTensor(*dst[i])));
ORT_THROW_HR_IF(E_INVALIDARG, dataSizesInBytes[i] != ComputeByteSizeFromTensor(*src[i])); // Tensors must be the same size
ORT_THROW_HR_IF(E_INVALIDARG, dataSizesInBytes.back() != ComputeByteSizeFromTensor(*src[i])); // Tensors must be the same size

dstDatas.push_back(dst[i]->GetData());
const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src[i]).GetDataInterface().Get());
Expand Down Expand Up @@ -621,7 +621,23 @@ namespace Dml
"SequenceEmpty",
"SequenceLength",
"SequenceErase",
"SequenceInsert",
"SequenceInsert"
};

for (auto& sequence_op : sequence_ops)
{
if (strcmp(sequence_op, node.OpType().c_str()) == 0)
{
return true;
}
}
return false;
}

bool IsDmlSequenceOperator(const onnxruntime::Node& node)
{
auto sequence_ops = std::array<char*, 1>{
"ConcatFromSequence"
};

for (auto& sequence_op : sequence_ops)
Expand Down Expand Up @@ -683,9 +699,12 @@ namespace Dml
}

// Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels.
if (edgeType == MLOperatorEdgeType::SequenceTensor && IsCpuOnDmlOperator(node))
if (edgeType == MLOperatorEdgeType::SequenceTensor)
{
// Leave nodeContainsSupportedDataTypes alone.
if (!IsCpuOnDmlOperator(node) && !IsDmlSequenceOperator(node))
{
nodeContainsSupportedDataTypes = false;
}
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,10 @@ namespace Windows::AI::MachineLearning::Adapter
}

template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetSequenceInputCount(uint32_t inputIndex, uint32_t* inputCount) const noexcept
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetSequenceInputInfo(
uint32_t inputIndex,
uint32_t* inputCount,
MLOperatorTensorDataType* dataType) const noexcept
{
ORT_TRY
{
Expand All @@ -1068,7 +1071,7 @@ namespace Windows::AI::MachineLearning::Adapter
auto inputTensorSeq = m_kernelContext->Input<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(inputIndex));
ML_CHECK_BOOL(inputTensorSeq != nullptr);
*inputCount = static_cast<uint32_t>(inputTensorSeq->Size());

*dataType = ToMLTensorDataType(inputTensorSeq->DataType());
return S_OK;
}
ORT_CATCH_RETURN
Expand Down Expand Up @@ -1975,6 +1978,28 @@ namespace Windows::AI::MachineLearning::Adapter
ORT_CATCH_RETURN
}

//! Sets the element data type on the sequence tensor output at outputIndex.
//! This allows an empty sequence (e.g. produced by SequenceEmpty on DirectML)
//! to carry a concrete element type so it can later be copied back to CPU.
//! Returns a failure HRESULT if the context is closed, the index is out of
//! range, or the output is not a sequence tensor.
HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::PrepareSequenceOutput(
    uint32_t outputIndex,
    MLOperatorTensorDataType dataType) const noexcept
{
    ORT_TRY
    {
        VerifyNotClosed();

        // Validate the index before asking the kernel context for the output.
        ML_CHECK_BOOL(outputIndex < m_outputTensors.size());
        auto outputTensorSeq = m_impl->Output<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(outputIndex));
        ML_CHECK_BOOL(outputTensorSeq != nullptr);

        // Record the element type on the sequence; a zero-length sequence has
        // no tensors from which the type could otherwise be inferred.
        auto mlDataType = ToMLDataType(MLOperatorEdgeType::Primitive, dataType);
        outputTensorSeq->SetType(mlDataType);

        return S_OK;
    }
    ORT_CATCH_RETURN
}

HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::GetSequenceOutputTensor(
uint32_t outputIndex,
uint32_t sequenceIndex,
Expand Down Expand Up @@ -2054,7 +2079,7 @@ namespace Windows::AI::MachineLearning::Adapter
ORT_CATCH_RETURN
}

HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::GetSequenceInputCount(uint32_t inputIndex, uint32_t* inputCount) const noexcept
HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::GetSequenceInputInfo(uint32_t inputIndex, uint32_t* inputCount, MLOperatorTensorDataType* dataType) const noexcept
{
ORT_TRY
{
Expand All @@ -2067,6 +2092,7 @@ namespace Windows::AI::MachineLearning::Adapter
auto inputTensorSeq = m_impl->Input<onnxruntime::TensorSeq>(gsl::narrow_cast<int>(inputIndex));
ML_CHECK_BOOL(inputTensorSeq != nullptr);
*inputCount = static_cast<uint32_t>(inputTensorSeq->Size());
*dataType = ToMLTensorDataType(inputTensorSeq->DataType());
return S_OK;
}
ORT_CATCH_RETURN
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class OpNodeInfoWrapper : public Base1_t, public Base2_t, public Closable
HRESULT STDMETHODCALLTYPE GetInputTensorDimensionCount(uint32_t inputIndex, uint32_t* dimensionCount) const noexcept;
HRESULT STDMETHODCALLTYPE GetInputTensorShape(uint32_t inputIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept;

HRESULT STDMETHODCALLTYPE GetSequenceInputCount(uint32_t inputIndex, uint32_t* inputCount) const noexcept;
HRESULT STDMETHODCALLTYPE GetSequenceInputInfo(uint32_t inputIndex, uint32_t* inputCount, MLOperatorTensorDataType* dataType) const noexcept;
HRESULT STDMETHODCALLTYPE GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex, uint32_t* dimensionCount) const noexcept;
HRESULT STDMETHODCALLTYPE GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex, uint32_t dimensionCount, uint32_t* dimensions) const noexcept;

Expand Down Expand Up @@ -480,9 +480,13 @@ class OpKernelContextWrapper : public WRL::Base<IMLOperatorKernelContext, IMLOpe
OpKernelContextWrapper(onnxruntime::OpKernelContext* context, const onnxruntime::IExecutionProvider* provider, bool isInternalOperator, const EdgeShapes* outputShapes);

bool STDMETHODCALLTYPE IsSequenceInputTensor(uint32_t inputIndex) const noexcept override;
HRESULT STDMETHODCALLTYPE GetSequenceInputCount(uint32_t inputIndex, uint32_t* inputCount) const noexcept override;
HRESULT STDMETHODCALLTYPE GetSequenceInputInfo(uint32_t inputIndex, uint32_t* inputCount, MLOperatorTensorDataType* dataType) const noexcept override;
HRESULT STDMETHODCALLTYPE GetSequenceInputTensor(uint32_t inputIndex, uint32_t sequenceIndex, IMLOperatorTensor** tensor) const noexcept override;

HRESULT STDMETHODCALLTYPE PrepareSequenceOutput(
uint32_t outputIndex,
MLOperatorTensorDataType dataType) const noexcept override;

HRESULT STDMETHODCALLTYPE GetSequenceOutputTensor(
uint32_t outputIndex,
uint32_t sequenceIndex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ class DmlOperatorMemcpy : public DmlOperator
void Compute(const MLOperatorKernelContext& kernelContext)
{
std::vector<IMLOperatorTensor*> inputTensors = GetInputTensors(kernelContext);
std::vector<IMLOperatorTensor*> outputTensors = GetInputTensors(kernelContext);
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensors(kernelContext);

if (kernelContext.IsSequenceInputTensor(0))
{
auto dataType = kernelContext.GetSequenceInputDataType(0);
kernelContext.PrepareSequenceOutput(0, dataType);

const uint32_t numTensors = kernelContext.GetSequenceInputCount(0);
inputTensors.reserve(numTensors);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,21 @@ class MLOperatorTensorShapeDescription
Microsoft::WRL::ComPtr<IMLOperatorTensorShapeDescriptionPrivate> private_impl;
m_impl.As(&private_impl);
uint32_t inputCount = 0;
ORT_THROW_IF_FAILED(private_impl->GetSequenceInputCount(inputIndex, &inputCount));
MLOperatorTensorDataType dataType;
ORT_THROW_IF_FAILED(private_impl->GetSequenceInputInfo(inputIndex, &inputCount, &dataType));
return inputCount;
}

//! Returns the element data type of the sequence tensor input at inputIndex.
//! Throws on failure (invalid index or non-sequence input).
MLOperatorTensorDataType GetSequenceInputDataType(uint32_t inputIndex) const
{
    Microsoft::WRL::ComPtr<IMLOperatorTensorShapeDescriptionPrivate> private_impl;
    // Check the QueryInterface result instead of dereferencing a potentially
    // null ComPtr if the private interface is not implemented.
    ORT_THROW_IF_FAILED(m_impl.As(&private_impl));
    uint32_t inputCount = 0;
    MLOperatorTensorDataType dataType;
    ORT_THROW_IF_FAILED(private_impl->GetSequenceInputInfo(inputIndex, &inputCount, &dataType));
    return dataType;
}

uint32_t GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const
{
Microsoft::WRL::ComPtr<IMLOperatorTensorShapeDescriptionPrivate> private_impl;
Expand Down Expand Up @@ -610,6 +621,12 @@ class MLOperatorKernelCreationContext : public MLOperatorAttributes
return shapeDesc.GetSequenceInputCount(inputIndex);
}

//! Returns the element data type of the sequence tensor input at inputIndex,
//! delegating to the tensor shape description for this kernel.
MLOperatorTensorDataType GetSequenceInputDataType(uint32_t inputIndex) const
{
    return GetTensorShapeDescription().GetSequenceInputDataType(inputIndex);
}

uint32_t GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const
{
auto shapeDesc = GetTensorShapeDescription();
Expand Down Expand Up @@ -696,10 +713,21 @@ class MLShapeInferenceContext : public MLOperatorAttributes
Microsoft::WRL::ComPtr<IMLOperatorShapeInferenceContextPrivate> private_impl;
m_impl.As(&private_impl);
uint32_t inputCount = 0;
ORT_THROW_IF_FAILED(private_impl->GetSequenceInputCount(inputIndex, &inputCount));
MLOperatorTensorDataType dataType;
ORT_THROW_IF_FAILED(private_impl->GetSequenceInputInfo(inputIndex, &inputCount, &dataType));
return inputCount;
}

//! Returns the element data type of the sequence tensor input at inputIndex.
//! Throws on failure (invalid index or non-sequence input).
MLOperatorTensorDataType GetSequenceInputDataType(uint32_t inputIndex) const
{
    Microsoft::WRL::ComPtr<IMLOperatorShapeInferenceContextPrivate> private_impl;
    // Check the QueryInterface result instead of dereferencing a potentially
    // null ComPtr if the private interface is not implemented.
    ORT_THROW_IF_FAILED(m_impl.As(&private_impl));
    uint32_t inputCount = 0;
    MLOperatorTensorDataType dataType;
    ORT_THROW_IF_FAILED(private_impl->GetSequenceInputInfo(inputIndex, &inputCount, &dataType));
    return dataType;
}

uint32_t GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const
{
Microsoft::WRL::ComPtr<IMLOperatorShapeInferenceContextPrivate> private_impl;
Expand Down Expand Up @@ -805,10 +833,21 @@ class MLOperatorKernelContext
Microsoft::WRL::ComPtr<IMLOperatorKernelContextPrivate> operatorKernelContext;
m_impl.As(&operatorKernelContext);
uint32_t inputCount = 0;
ORT_THROW_IF_FAILED(operatorKernelContext->GetSequenceInputCount(inputIndex, &inputCount));
MLOperatorTensorDataType dataType;
ORT_THROW_IF_FAILED(operatorKernelContext->GetSequenceInputInfo(inputIndex, &inputCount, &dataType));
return inputCount;
}

//! Returns the element data type of the sequence tensor input at inputIndex.
//! Throws on failure (invalid index or non-sequence input).
MLOperatorTensorDataType GetSequenceInputDataType(uint32_t inputIndex) const
{
    Microsoft::WRL::ComPtr<IMLOperatorKernelContextPrivate> operatorKernelContext;
    // Check the QueryInterface result instead of dereferencing a potentially
    // null ComPtr if the private interface is not implemented.
    ORT_THROW_IF_FAILED(m_impl.As(&operatorKernelContext));
    uint32_t inputCount = 0;
    MLOperatorTensorDataType dataType;
    ORT_THROW_IF_FAILED(operatorKernelContext->GetSequenceInputInfo(inputIndex, &inputCount, &dataType));
    return dataType;
}

MLOperatorTensor GetSequenceInputTensor(uint32_t inputIndex, uint32_t sequenceIndex) const
{
Microsoft::WRL::ComPtr<IMLOperatorKernelContextPrivate> operatorKernelContext;
Expand All @@ -820,6 +859,13 @@ class MLOperatorKernelContext
return tensor.Get();
}

//! Sets the element data type on the sequence tensor output at outputIndex,
//! so that an empty output sequence still carries a concrete element type.
//! Throws on failure.
void PrepareSequenceOutput(uint32_t outputIndex, MLOperatorTensorDataType dataType) const
{
    Microsoft::WRL::ComPtr<IMLOperatorKernelContextPrivate> operatorKernelContext;
    // Check the QueryInterface result instead of dereferencing a potentially
    // null ComPtr if the private interface is not implemented.
    ORT_THROW_IF_FAILED(m_impl.As(&operatorKernelContext));
    ORT_THROW_IF_FAILED(operatorKernelContext->PrepareSequenceOutput(outputIndex, dataType));
}

MLOperatorTensor GetSequenceOutputTensor(
uint32_t outputIndex,
uint32_t sequenceIndex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContex
) const noexcept PURE;

//! Gets the element count and data type of the sequence tensor input at the specified index.
STDMETHOD(GetSequenceInputCount)(
STDMETHOD(GetSequenceInputInfo)(
uint32_t inputIndex,
_Out_ uint32_t* inputCount
_Out_ uint32_t* inputCount,
MLOperatorTensorDataType* dataType
) const noexcept PURE;

//! Gets the number of dimensions of a tensor output of the operator.
Expand Down Expand Up @@ -187,9 +188,10 @@ interface DECLSPEC_UUID("440DA47C-018B-41F6-80A4-13FCF0544F37") DECLSPEC_NOVTABL
IMLOperatorTensorShapeDescriptionPrivate : IUnknown
{
//! Gets the element count and data type of the sequence tensor input at the specified index.
STDMETHOD(GetSequenceInputCount)(
STDMETHOD(GetSequenceInputInfo)(
uint32_t inputIndex,
_Out_ uint32_t* inputCount
_Out_ uint32_t* inputCount,
MLOperatorTensorDataType* dataType
) const noexcept PURE;

//! Gets the number of dimensions of a tensor input of the operator.
Expand Down Expand Up @@ -225,6 +227,11 @@ IMLOperatorKernelContextPrivate : IUnknown
_COM_Outptr_result_maybenull_ IMLOperatorTensor** tensor
) const noexcept PURE;

//! Prepares the sequence tensor output at the specified index by setting its element data type.
STDMETHOD(PrepareSequenceOutput)(
uint32_t outputIndex,
MLOperatorTensorDataType dataType) const noexcept PURE;

//! Gets the output tensor of the operator at the specified index.
//! This sets tensor to nullptr for optional outputs which do not exist.
//! Returns an error if the output at the specified index is not a tensor.
Expand All @@ -241,9 +248,10 @@ IMLOperatorKernelContextPrivate : IUnknown
//! Gets the element count and data type of the sequence tensor input
//! at the specified index.
//! Returns an error if the input at the specified index is not a sequence tensor.
STDMETHOD(GetSequenceInputCount)(
STDMETHOD(GetSequenceInputInfo)(
uint32_t inputIndex,
_Out_ uint32_t* inputCount
_Out_ uint32_t* inputCount,
MLOperatorTensorDataType* dataType
) const noexcept PURE;

//! Returns whether the tensor at inputIndex is a sequence tensor or not
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ struct IShapeInformationAdapter
virtual uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const = 0;
virtual std::vector<uint32_t> GetInputTensorShape(uint32_t inputIndex) const = 0;
virtual uint32_t GetSequenceInputCount(uint32_t inputIndex) const = 0;
virtual MLOperatorTensorDataType GetSequenceInputDataType(uint32_t inputIndex) const = 0;
virtual uint32_t GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const = 0;
virtual std::vector<uint32_t> GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex) const = 0;
virtual ~IShapeInformationAdapter() {}
Expand Down Expand Up @@ -348,6 +349,7 @@ struct ShapeInformationAdapter : IShapeInformationAdapter
virtual uint32_t GetInputTensorDimensionCount(uint32_t inputIndex) const { return m_informationSource.GetInputTensorDimensionCount(inputIndex); }
virtual std::vector<uint32_t> GetInputTensorShape(uint32_t inputIndex) const { return m_informationSource.GetInputTensorShape(inputIndex); }
virtual uint32_t GetSequenceInputCount(uint32_t inputIndex) const { return m_informationSource.GetSequenceInputCount(inputIndex); }
virtual MLOperatorTensorDataType GetSequenceInputDataType(uint32_t inputIndex) const { return m_informationSource.GetSequenceInputDataType(inputIndex); }
virtual uint32_t GetSequenceInputTensorDimensionCount(uint32_t inputIndex, uint32_t sequenceIndex) const { return m_informationSource.GetSequenceInputTensorDimensionCount(inputIndex, sequenceIndex); }
virtual std::vector<uint32_t> GetSequenceInputTensorShape(uint32_t inputIndex, uint32_t sequenceIndex) const { return m_informationSource.GetSequenceInputTensorShape(inputIndex, sequenceIndex); }
virtual ~ShapeInformationAdapter() {}
Expand Down
4 changes: 2 additions & 2 deletions winml/lib/Api.Ort/OnnxruntimeModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ STDMETHODIMP OnnruntimeModel::JoinModel(_In_ IModel* other_model,
_In_ const char* const join_node_prefix) {
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
auto ort_api = engine_factory_->UseOrtApi();

RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->JoinModels(ort_model_.get(),
static_cast<OnnruntimeModel*>(other_model)->ort_model_.get(),
output_names,
Expand All @@ -385,4 +385,4 @@ STDMETHODIMP OnnruntimeModel::JoinModel(_In_ IModel* other_model,
// reset the info so that it is recreated with the new information lazily
info_ = nullptr;
return S_OK;
}
}

0 comments on commit 7ccdf9a

Please sign in to comment.