Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added IDisposable support for several classes #4939

Merged
merged 13 commits into from
Mar 24, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ public static void Example()
var idv = mlContext.Data.LoadFromEnumerable(data);

// Create a ML pipeline.
var pipeline = mlContext.Model.LoadTensorFlowModel(modelLocation)
.ScoreTensorFlowModel(
using var model = mlContext.Model.LoadTensorFlowModel(modelLocation);
var pipeline = model.ScoreTensorFlowModel(
new[] { nameof(OutputScores.output) },
new[] { nameof(TensorData.input) }, addBatchDimensionInput: true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public static void Example()
// Unfrozen (SavedModel format) models are loaded by providing the
// path to the directory containing the model file and other model
// artifacts like pre-trained weights.
var tensorFlowModel = mlContext.Model.LoadTensorFlowModel(
using var tensorFlowModel = mlContext.Model.LoadTensorFlowModel(
modelLocation);
var schema = tensorFlowModel.GetModelSchema();
var featuresType = (VectorDataViewType)schema["Features"].Type;
Expand Down
17 changes: 16 additions & 1 deletion src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;
Expand All @@ -15,7 +16,7 @@ namespace Microsoft.ML.Data
/// This class represents a data loader that applies a transformer chain after loading.
/// It also has methods to save itself to a repository.
/// </summary>
public sealed class CompositeDataLoader<TSource, TLastTransformer> : IDataLoader<TSource>
public sealed class CompositeDataLoader<TSource, TLastTransformer> : IDataLoader<TSource>, IDisposable
Copy link
Member

@sharwell sharwell Mar 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📝 By making the class disposable when IDataLoader<TSource> is not disposable, it will be easy for a consumer to fail to dispose of the object. #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't change interfaces now that they are shipped. Right now, when customers discover they have a memory leak, we don't have a solution. With this method at least we can tell them to add a Dispose call.

Is there a better approach we can consider?


In reply to: 396132696 [](ancestors = 396132696)

Copy link
Member

@sharwell sharwell Mar 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Every case that I reviewed locally required breaking API changes. Default interface methods might be an option for .NET Core targets, but those are completely unsupported on .NET Framework. #Resolved

Copy link
Contributor

@justinormont justinormont Mar 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the component GA? Or can we still make a breaking change? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All public interfaces are frozen. This would be one of those.


In reply to: 396753025 [](ancestors = 396753025)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All public interfaces are frozen. This would be one of those.

There may be a difference between interfaces released as GA and the preview releases. I'm not sure which one this falls under.

/cc @eerhardt

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDataLoader is a core interface. It is part of the GA set.


In reply to: 396760924 [](ancestors = 396760924)

where TLastTransformer : class, ITransformer
{
internal const string TransformerDirectory = TransformerChain.LoaderSignature;
Expand Down Expand Up @@ -110,5 +111,19 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(CompositeDataLoader<,>).Assembly.FullName);
}

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;

Transformer.Dispose();

_disposed = true;
}
#endregion
}
}
17 changes: 16 additions & 1 deletion src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ internal interface ITransformerChainAccessor
/// A chain of transformers (possibly empty) that end with a <typeparamref name="TLastTransformer"/>.
/// For an empty chain, <typeparamref name="TLastTransformer"/> is always <see cref="ITransformer"/>.
/// </summary>
public sealed class TransformerChain<TLastTransformer> : ITransformer, IEnumerable<ITransformer>, ITransformerChainAccessor
public sealed class TransformerChain<TLastTransformer> : ITransformer, IEnumerable<ITransformer>, ITransformerChainAccessor, IDisposable
where TLastTransformer : class, ITransformer
{
private readonly ITransformer[] _transformers;
Expand Down Expand Up @@ -232,6 +232,21 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
}
return new CompositeRowToRowMapper(inputSchema, mappers);
}

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;

foreach (var transformer in _transformers)
(transformer as IDisposable)?.Dispose();

_disposed = true;
}
#endregion
}

/// <summary>
Expand Down
17 changes: 16 additions & 1 deletion src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Microsoft.ML.Data
/// A row-to-row mapper that is the result of a chained application of multiple mappers.
/// </summary>
[BestFriend]
internal sealed class CompositeRowToRowMapper : IRowToRowMapper
internal sealed class CompositeRowToRowMapper : IRowToRowMapper, IDisposable
{
[BestFriend]
internal IRowToRowMapper[] InnerMappers { get; }
Expand Down Expand Up @@ -118,5 +118,20 @@ public SubsetActive(DataViewRow row, Func<int, bool> pred)
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column) => _pred(column.Index);
}

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;

foreach (var mapper in InnerMappers)
(mapper as IDisposable)?.Dispose();

_disposed = true;
}
#endregion
}
}
13 changes: 4 additions & 9 deletions src/Microsoft.ML.Data/Prediction/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,13 @@ private protected virtual Func<DataViewSchema, IRowToRowMapper> TransformerCheck
}

public void Dispose()
{
Disposing(true);
GC.SuppressFinalize(this);
}

[BestFriend]
private protected void Disposing(bool disposing)
{
if (_disposed)
return;
if (disposing)
_disposer?.Invoke();

_disposer?.Invoke();
(Transformer as IDisposable)?.Dispose();

_disposed = true;
}

Expand Down
18 changes: 17 additions & 1 deletion src/Microsoft.ML.Data/Scorers/MulticlassClassificationScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ private static VersionInfo GetVersionInfo()
/// </summary>
// REVIEW: It seems like the attachment of metadata should be solvable in a manner
// less ridiculously verbose than this.
public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveModel, IBindableCanSavePfa, IBindableCanSaveOnnx
public sealed class LabelNameBindableMapper : ISchemaBindableMapper, ICanSaveModel, IBindableCanSavePfa,
IBindableCanSaveOnnx, IDisposable
{
private static readonly FuncInstanceMethodInfo1<LabelNameBindableMapper, object, Delegate> _decodeInitMethodInfo
= FuncInstanceMethodInfo1<LabelNameBindableMapper, object, Delegate>.Create(target => target.DecodeInit<int>);
Expand Down Expand Up @@ -379,6 +380,21 @@ public RowImpl(DataViewRow row, DataViewSchema schema)
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) => Input.GetGetter<TValue>(column);
}
}

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
// TODO: Is it necessary to call the base class Dispose()?
if (_disposed)
return;

(_bindable as IDisposable)?.Dispose();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ISchemaBindableMapper is internal, so if you want you can make it inherit from IDisposable without a breaking change.


_disposed = true;
}
#endregion
}

/// <summary>
Expand Down
17 changes: 16 additions & 1 deletion src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.ML.Data
/// Class for scorers that compute on additional "PredictedLabel" column from the score column.
/// Currently, this scorer is used for binary classification, multi-class classification, and clustering.
/// </summary>
internal abstract class PredictedLabelScorerBase : RowToRowScorerBase, ITransformCanSavePfa, ITransformCanSaveOnnx
internal abstract class PredictedLabelScorerBase : RowToRowScorerBase, ITransformCanSavePfa, ITransformCanSaveOnnx, IDisposable
{
public abstract class ThresholdArgumentsBase : ScorerArgumentsBase
{
Expand Down Expand Up @@ -435,5 +435,20 @@ protected void EnsureCachedPosition<TScore>(ref long cachedPosition, ref TScore
cachedPosition = boundRow.Position;
}
}

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;

(Bindings.RowMapper as IDisposable)?.Dispose();
(Bindable as IDisposable)?.Dispose();

_disposed = true;
}
#endregion
}
}
18 changes: 17 additions & 1 deletion src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ internal static class PredictionTransformerBase
/// Base class for transformers with no feature column, or more than one feature columns.
/// </summary>
/// <typeparam name="TModel">The type of the model parameters used by this prediction transformer.</typeparam>
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>
public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer<TModel>, IDisposable
where TModel : class
{
/// <summary>
Expand Down Expand Up @@ -181,6 +181,22 @@ private protected void SaveModelCore(ModelSaveContext ctx)
}
});
}

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;

(Model as IDisposable)?.Dispose();
(BindableMapper as IDisposable)?.Dispose();
(Scorer as IDisposable)?.Dispose();

_disposed = true;
}
#endregion
}

/// <summary>
Expand Down
32 changes: 30 additions & 2 deletions src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace Microsoft.ML.Data
/// This is a base class for wrapping <see cref="IPredictor"/>s in an <see cref="ISchemaBindableMapper"/>.
/// </summary>
internal abstract class SchemaBindablePredictorWrapperBase : ISchemaBindableMapper, ICanSaveModel, ICanSaveSummary,
IBindableCanSavePfa, IBindableCanSaveOnnx
IBindableCanSavePfa, IBindableCanSaveOnnx, IDisposable
{
// The ctor guarantees that Predictor is non-null. It also ensures that either
// ValueMapper or FloatPredictor is non-null (or both). With these guarantees,
Expand Down Expand Up @@ -193,7 +193,7 @@ void ICanSaveSummary.SaveSummary(TextWriter writer, RoleMappedSchema schema)
/// This class doesn't care. It DOES care that the role mapped schema specifies a unique Feature column.
/// It also requires that the output schema has ColumnCount == 1.
/// </summary>
protected sealed class SingleValueRowMapper : ISchemaBoundRowMapper
protected sealed class SingleValueRowMapper : ISchemaBoundRowMapper, IDisposable
{
private readonly SchemaBindablePredictorWrapperBase _parent;

Expand Down Expand Up @@ -241,7 +241,35 @@ DataViewRow ISchemaBoundRowMapper.GetRow(DataViewRow input, IEnumerable<DataView
getters[0] = _parent.GetPredictionGetter(input, InputRoleMappedSchema.Feature.Value.Index);
return new SimpleRow(OutputSchema, input, getters);
}

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;

(_parent as IDisposable)?.Dispose();

_disposed = true;
}
#endregion
}

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;

(Predictor as IDisposable)?.Dispose();

_disposed = true;
}
#endregion
}

/// <summary>
Expand Down
18 changes: 17 additions & 1 deletion src/Microsoft.ML.TensorFlow/TensorFlowModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.TensorFlow;
using Tensorflow;
Expand All @@ -13,7 +14,7 @@ namespace Microsoft.ML.Transforms
/// It provides some convenient methods to query model schema as well as
/// creation of <see cref="TensorFlowEstimator"/> object.
/// </summary>
public sealed class TensorFlowModel
public sealed class TensorFlowModel : IDisposable
{
internal Session Session { get; }
internal string ModelPath { get; }
Expand All @@ -31,6 +32,7 @@ internal TensorFlowModel(IHostEnvironment env, Session session, string modelLoca
Session = session;
ModelPath = modelLocation;
_env = env;
_disposed = false;
}

/// <summary>
Expand Down Expand Up @@ -83,5 +85,19 @@ public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string
/// </example>
public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames, bool addBatchDimensionInput = false)
=> new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, this, addBatchDimensionInput);

#region IDisposable Support
private bool _disposed;

public void Dispose()
{
if (_disposed)
return;

Session.Dispose();

_disposed = true;
}
#endregion
}
}
3 changes: 3 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ public static class TensorflowCatalog
/// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlowModel.ScoreTensorFlowModel(string, string, bool)"/>.
/// usage of this API requires additional NuGet dependencies on TensorFlow redist, see linked document for more information.
/// <see cref="TensorFlowModel"/> also holds references to unmanaged resources that need to be freed either with an explicit
/// call to Dispose() or implicitly by declaring the variable with the "using" syntax/>
///
/// <format type="text/markdown">
/// <![CDATA[
/// [!include[io](~/../docs/samples/docs/api-reference/tensorflow-usage.md)]
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, Graph grap
/// <param name="modelPath">Model to load.</param>
internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath)
{
var model = LoadTensorFlowModel(env, modelPath);
using var model = LoadTensorFlowModel(env, modelPath);
return GetModelSchema(env, model.Session.graph);
}

Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.ML.Vision/DnnRetrainTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,8 @@ public void Dispose()
{
if (_session != null && _session != IntPtr.Zero)
{
if (_session.graph != null)
_session.graph.Dispose();
_session.close();
}
}
Expand Down
Loading