Skip to content

Commit

Permalink
Update to Onnxruntime 1.2 and reenable its support for GPU (#4919)
Browse files Browse the repository at this point in the history
* Update Onnxruntime.managed version to 1.2

* Added dependencies to Onnxruntime (CPU) in test projects

* Reenabled GPU support
  • Loading branch information
antoniovs1029 authored Mar 10, 2020
1 parent 497708a commit c378772
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<GoogleProtobufPackageVersion>3.10.1</GoogleProtobufPackageVersion>
<LightGBMPackageVersion>2.2.3</LightGBMPackageVersion>
<MicrosoftExtensionsPackageVersion>2.1.0</MicrosoftExtensionsPackageVersion>
<MicrosoftMLOnnxRuntimePackageVersion>1.1.2</MicrosoftMLOnnxRuntimePackageVersion>
<MicrosoftMLOnnxRuntimePackageVersion>1.2</MicrosoftMLOnnxRuntimePackageVersion>
<MlNetMklDepsPackageVersion>0.0.0.9</MlNetMklDepsPackageVersion>
<ParquetDotNetPackageVersion>2.1.3</ParquetDotNetPackageVersion>
<SystemDrawingCommonPackageVersion>4.5.0</SystemDrawingCommonPackageVersion>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.ML.Onnx.TestModels" Version="$(MicrosoftMLOnnxTestModelsVersion)" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="$(TensorFlowVersion)" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<ItemGroup>
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)"/>
<PackageReference Include="Microsoft.ML.OnnxRuntime.Managed" Version="$(MicrosoftMLOnnxRuntimePackageVersion)"/>
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
<PackageReference Include="Microsoft.ML.OnnxRuntime.Managed" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
</ItemGroup>

Expand Down
16 changes: 13 additions & 3 deletions src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,19 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =

if (gpuDeviceId != null)
{
// The onnxruntime v1.0 currently does not support running on the GPU on all of ML.NET's supported platforms.
// This code path will be re-enabled when there is appropriate support in onnxruntime
throw new NotSupportedException("Running Onnx models on a GPU is temporarily not supported!");
try
{
_session = new InferenceSession(modelFile,
SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId.Value));
}
catch(OnnxRuntimeException)
{
if (fallbackToCpu)
_session = new InferenceSession(modelFile);
else
// If called from OnnxTransform, is caught and rethrown
throw;
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.ML.Onnx.TestModels" Version="$(MicrosoftMLOnnxTestModelsVersion)" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.ML.Onnx.TestModels" Version="$(MicrosoftMLOnnxTestModelsVersion)" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down
1 change: 1 addition & 0 deletions test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
<PackageReference Include="Microsoft.ML.TestModels" Version="$(MicrosoftMLTestModelsPackageVersion)" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="$(TensorFlowVersion)" />
<PackageReference Include="System.Data.SqlClient" Version="$(SystemDataSqlClientVersion)" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="$(MicrosoftMLOnnxRuntimePackageVersion)" />
</ItemGroup>

<ItemGroup>
Expand Down

0 comments on commit c378772

Please sign in to comment.