Create SafeBoosterHandle and SafeDataSetHandle #4539

Merged 1 commit on Dec 10, 2019
10 changes: 4 additions & 6 deletions src/Microsoft.ML.LightGbm/WrappedLightGbmBooster.cs
@@ -19,14 +19,13 @@ internal sealed class Booster : IDisposable
private readonly bool _hasValid;
private readonly bool _hasMetric;

public IntPtr Handle { get; private set; }
public WrappedLightGbmInterface.SafeBoosterHandle Handle { get; private set; }
public int BestIteration { get; set; }

public Booster(Dictionary<string, object> parameters, Dataset trainset, Dataset validset = null)
{
var param = LightGbmInterfaceUtils.JoinParameters(parameters);
var handle = IntPtr.Zero;
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterCreate(trainset.Handle, param, ref handle));
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterCreate(trainset.Handle, param, out var handle));
Comment from the PR author (Member):

📝 The change to SafeHandle means that the finalizer for SafeBoosterHandle will now be invoked at some point if the call to BoosterCreate succeeds but an exception is thrown later in this constructor (i.e. an exception prevents the caller from being able to use Booster.Dispose()). Ideally, code which constructs disposable objects should be robust in ensuring created objects are disposed if a failure would prevent those objects from being disposed by the caller. I'm not ready to submit the change to ensure all the callers dispose of objects deterministically on exceptional paths, but I'd like to merge this one in the meantime.

Reply from @eerhardt (Member), Dec 9, 2019:

From a quick scan of the usages of the Booster and Dataset classes, they are always disposed of in the product code (one in a using and the other in a try-finally). The only case I saw where they aren't disposed of is in the tests:

var gbmNative = WrappedLightGbmTraining.Train(ch, pch, gbmParams, gbmDataSet, numIteration: numberOfTrainingIterations);

var gbmDataSet = new Trainers.LightGbm.Dataset(sampleValueGroupedByColumn, sampleIndicesGroupedByColumn, _columnNumber, sampleNonZeroCntPerColumn, _rowNumber, _rowNumber, "", floatLabels);

Reply from the PR author (Member):

The Booster class is disposed if the constructor returns without throwing an exception. However, if LightGbmInterfaceUtils.Check throws an exception from the constructor itself, the SafeBoosterHandle will have been created by the call to BoosterCreate but never stored in an object that is later disposed. Prior to the use of a SafeHandle, an instance in this scenario would never be cleaned up. Now it will be cleaned up by the finalizer.
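
The deterministic cleanup discussed in this thread could look roughly like the sketch below. It is an illustration under stated assumptions, not code from this PR: `SafeDemoHandle`, `DemoOwner`, and `RunFailedConstruction` are hypothetical names, and the handle wraps memory from `Marshal.AllocHGlobal` instead of a LightGBM booster so that it runs without the native library. The constructor disposes the handle itself when a later step throws, so the finalizer is only a fallback.

```csharp
using System;
using System.Runtime.InteropServices;
using Microsoft.Win32.SafeHandles;

// Hypothetical stand-in for SafeBoosterHandle: owns a block of unmanaged memory
// instead of a LightGBM booster so the sketch runs without the native library.
public sealed class SafeDemoHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    // Counts releases so the cleanup can be observed.
    public static int ReleaseCount;

    public SafeDemoHandle() : base(ownsHandle: true)
    {
        SetHandle(Marshal.AllocHGlobal(16));
    }

    protected override bool ReleaseHandle()
    {
        Marshal.FreeHGlobal(handle);
        ReleaseCount++;
        return true;
    }
}

// Hypothetical owner type playing the role of Booster.
public sealed class DemoOwner : IDisposable
{
    public SafeDemoHandle Handle { get; private set; }

    public DemoOwner(bool failAfterCreate)
    {
        var handle = new SafeDemoHandle();
        try
        {
            // Stands in for later constructor work that may throw.
            if (failAfterCreate)
                throw new InvalidOperationException("simulated failure after handle creation");
            Handle = handle;
        }
        catch
        {
            // The caller never receives this object, so release the handle here
            // rather than leaving it to the finalizer.
            handle.Dispose();
            throw;
        }
    }

    public void Dispose()
    {
        Handle?.Dispose();
        Handle = null;
    }
}

public static class Program
{
    // Returns how many handles were released by the failed construction.
    public static int RunFailedConstruction()
    {
        int before = SafeDemoHandle.ReleaseCount;
        try
        {
            using (new DemoOwner(failAfterCreate: true)) { }
        }
        catch (InvalidOperationException)
        {
            // Expected: the constructor rethrows after cleaning up.
        }
        return SafeDemoHandle.ReleaseCount - before;
    }

    public static void Main()
    {
        Console.WriteLine(RunFailedConstruction());
    }
}
```

Without the `catch` block the handle would still be released eventually, because `SafeHandle` has a finalizer, but the timing would depend on the garbage collector.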

Handle = handle;
if (validset != null)
{
@@ -284,9 +283,8 @@ public InternalTreeEnsemble GetModel(int[] categoricalFeatureBoudaries)
#region IDisposable Support
public void Dispose()
{
if (Handle != IntPtr.Zero)
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterFree(Handle));
Handle = IntPtr.Zero;
Handle?.Dispose();
Handle = null;
}
#endregion
}
19 changes: 8 additions & 11 deletions src/Microsoft.ML.LightGbm/WrappedLightGbmDataset.cs
@@ -13,9 +13,9 @@ namespace Microsoft.ML.Trainers.LightGbm
/// </summary>
internal sealed class Dataset : IDisposable
{
private IntPtr _handle;
private WrappedLightGbmInterface.SafeDataSetHandle _handle;
private int _lastPushedRowID;
public IntPtr Handle => _handle;
public WrappedLightGbmInterface.SafeDataSetHandle Handle => _handle;

/// <summary>
/// Create a <see cref="Dataset"/> for storing training and prediciton data under LightGBM framework. The main goal of this function
@@ -46,7 +46,7 @@ public unsafe Dataset(double[][] sampleValuePerColumn,
int numTotalRow,
string param, float[] labels, float[] weights = null, int[] groups = null)
{
_handle = IntPtr.Zero;
_handle = null;

// Use GCHandle to pin the memory, avoid the memory relocation.
GCHandle[] gcValues = new GCHandle[numCol];
@@ -68,7 +68,7 @@ public unsafe Dataset(double[][] sampleValuePerColumn,
// Create container. Examples will pushed in later.
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateFromSampledColumn(
(IntPtr)ptrValues, (IntPtr)ptrIndices, numCol, sampleNonZeroCntPerColumn, numSampleRow, numTotalRow,
param, ref _handle));
param, out _handle));
}
}
finally
@@ -92,11 +92,9 @@ public unsafe Dataset(double[][] sampleValuePerColumn,

public Dataset(Dataset reference, int numTotalRow, float[] labels, float[] weights = null, int[] groups = null)
{
IntPtr refHandle = IntPtr.Zero;
if (reference != null)
refHandle = reference.Handle;
WrappedLightGbmInterface.SafeDataSetHandle refHandle = reference?.Handle;

LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateByReference(refHandle, numTotalRow, ref _handle));
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateByReference(refHandle, numTotalRow, out _handle));

SetLabel(labels);
SetWeights(weights);
@@ -105,9 +103,8 @@ public Dataset(Dataset reference, int numTotalRow, float[] labels, float[] weigh

public void Dispose()
{
if (_handle != IntPtr.Zero)
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetFree(_handle));
_handle = IntPtr.Zero;
_handle?.Dispose();
_handle = null;
}

/// <summary>
69 changes: 49 additions & 20 deletions src/Microsoft.ML.LightGbm/WrappedLightGbmInterface.cs
@@ -8,6 +8,7 @@
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.ML.Runtime;
using Microsoft.Win32.SafeHandles;

namespace Microsoft.ML.Trainers.LightGbm
{
@@ -66,6 +67,20 @@ public static extern int FreeArray(

#region API Dataset

public sealed class SafeDataSetHandle : SafeHandleZeroOrMinusOneIsInvalid
{
private SafeDataSetHandle()
: base(true)
{
}

protected override bool ReleaseHandle()
{
LightGbmInterfaceUtils.Check(DatasetFree(handle));
return true;
}
}

[DllImport(DllName, EntryPoint = "LGBM_DatasetCreateFromSampledColumn", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetCreateFromSampledColumn(IntPtr sampleValuePerColumn,
IntPtr sampleIndicesPerColumn,
@@ -74,22 +89,22 @@ public static extern int DatasetCreateFromSampledColumn(IntPtr sampleValuePerCol
int numSampleRow,
int numTotalRow,
[MarshalAs(UnmanagedType.LPStr)]string parameters,
ref IntPtr ret);
out SafeDataSetHandle ret);

[DllImport(DllName, EntryPoint = "LGBM_DatasetCreateByReference", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetCreateByReference(IntPtr reference,
public static extern int DatasetCreateByReference(SafeDataSetHandle reference,
long numRow,
ref IntPtr ret);
out SafeDataSetHandle ret);

[DllImport(DllName, EntryPoint = "LGBM_DatasetPushRows", CallingConvention = CallingConvention.StdCall)]
private static extern int DatasetPushRows(IntPtr dataset,
private static extern int DatasetPushRows(SafeDataSetHandle dataset,
float[] data,
CApiDType dataType,
int numRow,
int numCol,
int startRowIdx);

public static int DatasetPushRows(IntPtr dataset,
public static int DatasetPushRows(SafeDataSetHandle dataset,
float[] data,
int numRow,
int numCol,
@@ -99,7 +114,7 @@ public static int DatasetPushRows(IntPtr dataset,
}

[DllImport(DllName, EntryPoint = "LGBM_DatasetPushRowsByCSR", CallingConvention = CallingConvention.StdCall)]
private static extern int DatasetPushRowsByCsr(IntPtr dataset,
private static extern int DatasetPushRowsByCsr(SafeDataSetHandle dataset,
int[] indPtr,
CApiDType indPtrType,
int[] indices,
@@ -110,7 +125,7 @@ private static extern int DatasetPushRowsByCsr(IntPtr dataset,
long numCol,
long startRowIdx);

public static int DatasetPushRowsByCsr(IntPtr dataset,
public static int DatasetPushRowsByCsr(SafeDataSetHandle dataset,
int[] indPtr,
int[] indices,
float[] data,
@@ -126,39 +141,53 @@ public static int DatasetPushRowsByCsr(IntPtr dataset,
}

[DllImport(DllName, EntryPoint = "LGBM_DatasetFree", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetFree(IntPtr handle);
private static extern int DatasetFree(IntPtr handle);

[DllImport(DllName, EntryPoint = "LGBM_DatasetSetField", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetSetField(
IntPtr handle,
SafeDataSetHandle handle,
[MarshalAs(UnmanagedType.LPStr)]string field,
IntPtr array,
int len,
CApiDType type);

[DllImport(DllName, EntryPoint = "LGBM_DatasetGetNumData", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetGetNumData(IntPtr handle, ref int res);
public static extern int DatasetGetNumData(SafeDataSetHandle handle, ref int res);

[DllImport(DllName, EntryPoint = "LGBM_DatasetGetNumFeature", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetGetNumFeature(IntPtr handle, ref int res);
public static extern int DatasetGetNumFeature(SafeDataSetHandle handle, ref int res);

#endregion

#region API Booster

public sealed class SafeBoosterHandle : SafeHandleZeroOrMinusOneIsInvalid
{
private SafeBoosterHandle()
: base(true)
{
}

protected override bool ReleaseHandle()
{
LightGbmInterfaceUtils.Check(BoosterFree(handle));
return true;
}
}

[DllImport(DllName, EntryPoint = "LGBM_BoosterCreate", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterCreate(IntPtr trainset,
public static extern int BoosterCreate(SafeDataSetHandle trainset,
[MarshalAs(UnmanagedType.LPStr)]string param,
ref IntPtr res);
out SafeBoosterHandle res);

[DllImport(DllName, EntryPoint = "LGBM_BoosterFree", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterFree(IntPtr handle);
private static extern int BoosterFree(IntPtr handle);

[DllImport(DllName, EntryPoint = "LGBM_BoosterAddValidData", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterAddValidData(IntPtr handle, IntPtr validset);
public static extern int BoosterAddValidData(SafeBoosterHandle handle, SafeDataSetHandle validset);

[DllImport(DllName, EntryPoint = "LGBM_BoosterSaveModelToString", CallingConvention = CallingConvention.StdCall)]
public static extern unsafe int BoosterSaveModelToString(IntPtr handle,
public static extern unsafe int BoosterSaveModelToString(SafeBoosterHandle handle,
int startIteration,
int numIteration,
int bufferLen,
@@ -170,20 +199,20 @@ public static extern unsafe int BoosterSaveModelToString(IntPtr handle,
#region API train

[DllImport(DllName, EntryPoint = "LGBM_BoosterUpdateOneIter", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterUpdateOneIter(IntPtr handle, ref int isFinished);
public static extern int BoosterUpdateOneIter(SafeBoosterHandle handle, ref int isFinished);

[DllImport(DllName, EntryPoint = "LGBM_BoosterGetEvalCounts", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterGetEvalCounts(IntPtr handle, ref int outLen);
public static extern int BoosterGetEvalCounts(SafeBoosterHandle handle, ref int outLen);

[DllImport(DllName, EntryPoint = "LGBM_BoosterGetEval", CallingConvention = CallingConvention.StdCall)]
public static extern unsafe int BoosterGetEval(IntPtr handle, int dataIdx,
public static extern unsafe int BoosterGetEval(SafeBoosterHandle handle, int dataIdx,
ref int outLen, double* outResult);

#endregion

#region API predict
[DllImport(DllName, EntryPoint = "LGBM_BoosterPredictForMat", CallingConvention = CallingConvention.StdCall)]
public static extern unsafe int BoosterPredictForMat(IntPtr handle, IntPtr data, CApiDType dataType, int nRow, int nCol, int isRowMajor,
public static extern unsafe int BoosterPredictForMat(SafeBoosterHandle handle, IntPtr data, CApiDType dataType, int nRow, int nCol, int isRowMajor,
int predictType, int numIteration, [MarshalAs(UnmanagedType.LPStr)]string parameters, ref int outLen, double* outResult);
#endregion
