Skip to content

Commit c6bb688

Browse files
committed
Implement IDisposable
1 parent 76a7d0d commit c6bb688

File tree

3 files changed

+115
-19
lines changed

3 files changed

+115
-19
lines changed

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,19 +179,25 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
179179
foreach (var col in options.OutputColumns)
180180
Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns));
181181

182-
// Use ONNXRuntime to figure out the right input ane output configuration.
182+
// Use ONNXRuntime to figure out the right input and output configuration.
183183
// However, ONNXRuntime doesn't provide strongly-typed method to access the produced
184184
// variables, we will inspect the ONNX model file to get information regarding types.
185185
try
186186
{
187187
if (modelBytes == null)
188188
{
189+
// Entering this region means that the model file is passed in by the user.
189190
Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
190191
Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
191-
Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu);
192+
// Because we cannot delete the user file, ownModelFile should be false.
193+
Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false);
192194
}
193195
else
196+
{
197+
// Entering this region means that the byte[] is passed as the model. To feed that byte[] to ONNXRuntime, we need
198+
// to create a temporal file to store it and then call ONNXRuntime's API to load that file.
194199
Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu);
200+
}
195201
}
196202
catch (OnnxRuntimeException e)
197203
{
@@ -283,7 +289,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
283289
ctx.CheckAtModel();
284290
ctx.SetVersionInfo(GetVersionInfo());
285291

286-
ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(Model.ToByteArray()); });
292+
ctx.SaveBinaryStream("OnnxModel", w => { w.WriteByteArray(File.ReadAllBytes(Model.ModelFile)); });
287293

288294
Host.CheckNonEmpty(Inputs, nameof(Inputs));
289295
ctx.Writer.Write(Inputs.Length);

src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace Microsoft.ML.Transforms.Onnx
2121
/// It provides API to open a session, score tensors (NamedOnnxValues) and return
2222
/// the results.
2323
/// </summary>
24-
internal sealed class OnnxModel
24+
internal sealed class OnnxModel : IDisposable
2525
{
2626
/// <summary>
2727
/// OnnxModelInfo contains the data that we should get from
@@ -112,18 +112,39 @@ public OnnxVariableInfo(string name, OnnxShape shape, System.Type ortType, DataV
112112
}
113113
}
114114

115-
public readonly OnnxModelInfo ModelInfo;
115+
/// <summary>
116+
/// The ONNXRuntime facility to execute the loaded ONNX model.
117+
/// </summary>
116118
private readonly InferenceSession _session;
117-
private readonly string _modelFile;
119+
/// <summary>
120+
/// Indicates if <see cref="ModelFile"/> is a temporal file created by <see cref="CreateFromBytes(byte[], int?, bool)"/>
121+
/// or <see cref="CreateFromBytes(byte[])"/>. If <see langword="true"/>, <see cref="Dispose(bool)"/> should delete <see cref="ModelFile"/>.
122+
/// </summary>
123+
private bool _ownModelFile;
124+
/// <summary>
125+
/// The ONNX model file that <see cref="OnnxModel"/> built upon.
126+
/// </summary>
127+
internal OnnxModelInfo ModelInfo { get; }
128+
/// <summary>
129+
/// The location where the used ONNX model loaded from.
130+
/// </summary>
131+
internal string ModelFile { get; }
118132

119133
/// <summary>
120134
/// Constructs OnnxModel object from file.
121135
/// </summary>
122136
/// <param name="modelFile">Model file path.</param>
123137
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
124138
/// <param name="fallbackToCpu">If true, resumes CPU execution quitely upon GPU error.</param>
125-
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false)
139+
/// <param name="ownModelFile">If true, the <paramref name="modelFile"/> will be deleted when <see cref="OnnxModel"/> is
140+
/// no longer needed.</param>
141+
public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false, bool ownModelFile=false)
126142
{
143+
ModelFile = modelFile;
144+
// If we don't own the model file, _disposed should be false to prevent deleting user's file.
145+
_ownModelFile = ownModelFile;
146+
_disposed = false;
147+
127148
if (gpuDeviceId != null)
128149
{
129150
try
@@ -147,7 +168,7 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
147168

148169
// Load ONNX model file and parse its input and output schema. The reason of doing so is that ONNXRuntime
149170
// doesn't expose full type information via its C# APIs.
150-
_modelFile = modelFile;
171+
ModelFile = modelFile;
151172
var model = new OnnxCSharpToProtoWrapper.ModelProto();
152173
using (var modelStream = File.OpenRead(modelFile))
153174
model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(modelStream);
@@ -191,7 +212,9 @@ public OnnxModel(string modelFile, int? gpuDeviceId = null, bool fallbackToCpu =
191212
}
192213

193214
/// <summary>
194-
/// Create an OnnxModel from a byte[]
215+
/// Create an OnnxModel from a byte[]. Usually, a ONNX model is consumed by <see cref="OnnxModel"/> as a file.
216+
/// With <see cref="CreateFromBytes(byte[])"/> and <see cref="CreateFromBytes(byte[], int?, bool)"/>, it's possible
217+
/// to use in-memory model (type: byte[]) to create <see cref="OnnxModel"/>.
195218
/// </summary>
196219
/// <param name="modelBytes">Bytes of the serialized model</param>
197220
/// <returns>OnnxModel</returns>
@@ -202,6 +225,9 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes)
202225

203226
/// <summary>
204227
/// Create an OnnxModel from a byte[]. Set execution to GPU if required.
228+
/// Usually, a ONNX model is consumed by <see cref="OnnxModel"/> as a file.
229+
/// With <see cref="CreateFromBytes(byte[])"/> and <see cref="CreateFromBytes(byte[], int?, bool)"/>,
230+
/// it's possible to use in-memory model (type: byte[]) to create <see cref="OnnxModel"/>.
205231
/// </summary>
206232
/// <param name="modelBytes">Bytes of the serialized model.</param>
207233
/// <param name="gpuDeviceId">GPU device ID to execute on. Null for CPU.</param>
@@ -214,12 +240,7 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu
214240

215241
var tempModelFile = Path.Combine(tempModelDir, "model.onnx");
216242
File.WriteAllBytes(tempModelFile, modelBytes);
217-
return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu);
218-
219-
// TODO:
220-
// tempModelFile is needed in case the model needs to be saved
221-
// Either have to save the modelbytes and delete the temp dir/file,
222-
// or keep the dir/file and write proper cleanup when application closes
243+
return new OnnxModel(tempModelFile, gpuDeviceId, fallbackToCpu, ownModelFile: true);
223244
}
224245

225246
/// <summary>
@@ -233,12 +254,42 @@ public IReadOnlyCollection<NamedOnnxValue> Run(List<NamedOnnxValue> inputNamedOn
233254
}
234255

235256
/// <summary>
236-
/// Convert the model to a byte array.
257+
/// Flag used to indicate if the unmanaged resources (aka the model file <see cref="ModelFile"/>
258+
/// and <see cref="_session"/>) have been deleted.
237259
/// </summary>
238-
/// <returns>byte[]</returns>
239-
public byte[] ToByteArray()
260+
private bool _disposed;
261+
262+
public void Dispose()
263+
{
264+
Dispose(true);
265+
GC.SuppressFinalize(this);
266+
}
267+
268+
/// <summary>
269+
/// There are two unmanaged resources we can dispose, <see cref="_session"/> and <see cref="ModelFile"/>
270+
/// if <see cref="_ownModelFile"/> is <see langword="true"/>.
271+
/// </summary>
272+
/// <param name="disposing"></param>
273+
private void Dispose(bool disposing)
274+
{
275+
if (!_disposed)
276+
{
277+
// There are two things to be disposed.
278+
if (disposing)
279+
{
280+
// First, we release the resource token by ONNXRuntime.
281+
_session.Dispose();
282+
// Second, we delete the model file if that file is not created by the user.
283+
if (_ownModelFile && File.Exists(ModelFile))
284+
File.Delete(ModelFile);
285+
}
286+
_disposed = true;
287+
}
288+
}
289+
290+
~OnnxModel()
240291
{
241-
return File.ReadAllBytes(_modelFile);
292+
Dispose(false);
242293
}
243294
}
244295

test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,5 +518,44 @@ public void TestOnnxZipMapWithStringKeys()
518518
Assert.Equal(dataPoints[i].Input[2], dictionary["C"]);
519519
}
520520
}
521+
522+
[Fact]
523+
public void TestOnnxModelDisposal()
524+
{
525+
// Create a ONNX model as a byte[].
526+
var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx");
527+
var modelInBytes = File.ReadAllBytes(modelFile);
528+
529+
// Create ONNX model from the byte[].
530+
var onnxModel = OnnxModel.CreateFromBytes(modelInBytes);
531+
532+
// Check if a temporal file is crated for storing the byte[].
533+
Assert.True(File.Exists(onnxModel.ModelFile));
534+
535+
// Delete the temporal file.
536+
onnxModel.Dispose();
537+
538+
// Make sure the temporal file is deleted.
539+
Assert.False(File.Exists(onnxModel.ModelFile));
540+
}
541+
542+
[Fact]
543+
public void TestOnnxModelNotDisposal()
544+
{
545+
// Declare the path the tested ONNX model file.
546+
var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapInt64.onnx");
547+
548+
// Create ONNX model from the model file.
549+
var onnxModel = new OnnxModel(modelFile);
550+
551+
// Check if a temporal file is crated for storing the byte[].
552+
Assert.True(File.Exists(onnxModel.ModelFile));
553+
554+
// Don't delete the temporal file!
555+
onnxModel.Dispose();
556+
557+
// Make sure the temporal file still exists.
558+
Assert.True(File.Exists(onnxModel.ModelFile));
559+
}
521560
}
522561
}

0 commit comments

Comments
 (0)