Skip to content

Commit

Permalink
Create SafeBoosterHandle and SafeDataSetHandle (#4539)
Browse files Browse the repository at this point in the history
  • Loading branch information
sharwell authored Dec 10, 2019
1 parent 8221dac commit 2e0000f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 37 deletions.
10 changes: 4 additions & 6 deletions src/Microsoft.ML.LightGbm/WrappedLightGbmBooster.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ internal sealed class Booster : IDisposable
private readonly bool _hasValid;
private readonly bool _hasMetric;

public IntPtr Handle { get; private set; }
public WrappedLightGbmInterface.SafeBoosterHandle Handle { get; private set; }
public int BestIteration { get; set; }

public Booster(Dictionary<string, object> parameters, Dataset trainset, Dataset validset = null)
{
var param = LightGbmInterfaceUtils.JoinParameters(parameters);
var handle = IntPtr.Zero;
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterCreate(trainset.Handle, param, ref handle));
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterCreate(trainset.Handle, param, out var handle));
Handle = handle;
if (validset != null)
{
Expand Down Expand Up @@ -284,9 +283,8 @@ public InternalTreeEnsemble GetModel(int[] categoricalFeatureBoudaries)
#region IDisposable Support
public void Dispose()
{
if (Handle != IntPtr.Zero)
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.BoosterFree(Handle));
Handle = IntPtr.Zero;
Handle?.Dispose();
Handle = null;
}
#endregion
}
Expand Down
19 changes: 8 additions & 11 deletions src/Microsoft.ML.LightGbm/WrappedLightGbmDataset.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace Microsoft.ML.Trainers.LightGbm
/// </summary>
internal sealed class Dataset : IDisposable
{
private IntPtr _handle;
private WrappedLightGbmInterface.SafeDataSetHandle _handle;
private int _lastPushedRowID;
public IntPtr Handle => _handle;
public WrappedLightGbmInterface.SafeDataSetHandle Handle => _handle;

/// <summary>
/// Create a <see cref="Dataset"/> for storing training and prediciton data under LightGBM framework. The main goal of this function
Expand Down Expand Up @@ -46,7 +46,7 @@ public unsafe Dataset(double[][] sampleValuePerColumn,
int numTotalRow,
string param, float[] labels, float[] weights = null, int[] groups = null)
{
_handle = IntPtr.Zero;
_handle = null;

// Use GCHandle to pin the memory, avoid the memory relocation.
GCHandle[] gcValues = new GCHandle[numCol];
Expand All @@ -68,7 +68,7 @@ public unsafe Dataset(double[][] sampleValuePerColumn,
// Create container. Examples will pushed in later.
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateFromSampledColumn(
(IntPtr)ptrValues, (IntPtr)ptrIndices, numCol, sampleNonZeroCntPerColumn, numSampleRow, numTotalRow,
param, ref _handle));
param, out _handle));
}
}
finally
Expand All @@ -92,11 +92,9 @@ public unsafe Dataset(double[][] sampleValuePerColumn,

public Dataset(Dataset reference, int numTotalRow, float[] labels, float[] weights = null, int[] groups = null)
{
IntPtr refHandle = IntPtr.Zero;
if (reference != null)
refHandle = reference.Handle;
WrappedLightGbmInterface.SafeDataSetHandle refHandle = reference?.Handle;

LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateByReference(refHandle, numTotalRow, ref _handle));
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetCreateByReference(refHandle, numTotalRow, out _handle));

SetLabel(labels);
SetWeights(weights);
Expand All @@ -105,9 +103,8 @@ public Dataset(Dataset reference, int numTotalRow, float[] labels, float[] weigh

public void Dispose()
{
if (_handle != IntPtr.Zero)
LightGbmInterfaceUtils.Check(WrappedLightGbmInterface.DatasetFree(_handle));
_handle = IntPtr.Zero;
_handle?.Dispose();
_handle = null;
}

/// <summary>
Expand Down
69 changes: 49 additions & 20 deletions src/Microsoft.ML.LightGbm/WrappedLightGbmInterface.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.ML.Runtime;
using Microsoft.Win32.SafeHandles;

namespace Microsoft.ML.Trainers.LightGbm
{
Expand Down Expand Up @@ -66,6 +67,20 @@ public static extern int FreeArray(

#region API Dataset

public sealed class SafeDataSetHandle : SafeHandleZeroOrMinusOneIsInvalid
{
private SafeDataSetHandle()
: base(true)
{
}

protected override bool ReleaseHandle()
{
LightGbmInterfaceUtils.Check(DatasetFree(handle));
return true;
}
}

[DllImport(DllName, EntryPoint = "LGBM_DatasetCreateFromSampledColumn", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetCreateFromSampledColumn(IntPtr sampleValuePerColumn,
IntPtr sampleIndicesPerColumn,
Expand All @@ -74,22 +89,22 @@ public static extern int DatasetCreateFromSampledColumn(IntPtr sampleValuePerCol
int numSampleRow,
int numTotalRow,
[MarshalAs(UnmanagedType.LPStr)]string parameters,
ref IntPtr ret);
out SafeDataSetHandle ret);

[DllImport(DllName, EntryPoint = "LGBM_DatasetCreateByReference", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetCreateByReference(IntPtr reference,
public static extern int DatasetCreateByReference(SafeDataSetHandle reference,
long numRow,
ref IntPtr ret);
out SafeDataSetHandle ret);

[DllImport(DllName, EntryPoint = "LGBM_DatasetPushRows", CallingConvention = CallingConvention.StdCall)]
private static extern int DatasetPushRows(IntPtr dataset,
private static extern int DatasetPushRows(SafeDataSetHandle dataset,
float[] data,
CApiDType dataType,
int numRow,
int numCol,
int startRowIdx);

public static int DatasetPushRows(IntPtr dataset,
public static int DatasetPushRows(SafeDataSetHandle dataset,
float[] data,
int numRow,
int numCol,
Expand All @@ -99,7 +114,7 @@ public static int DatasetPushRows(IntPtr dataset,
}

[DllImport(DllName, EntryPoint = "LGBM_DatasetPushRowsByCSR", CallingConvention = CallingConvention.StdCall)]
private static extern int DatasetPushRowsByCsr(IntPtr dataset,
private static extern int DatasetPushRowsByCsr(SafeDataSetHandle dataset,
int[] indPtr,
CApiDType indPtrType,
int[] indices,
Expand All @@ -110,7 +125,7 @@ private static extern int DatasetPushRowsByCsr(IntPtr dataset,
long numCol,
long startRowIdx);

public static int DatasetPushRowsByCsr(IntPtr dataset,
public static int DatasetPushRowsByCsr(SafeDataSetHandle dataset,
int[] indPtr,
int[] indices,
float[] data,
Expand All @@ -126,39 +141,53 @@ public static int DatasetPushRowsByCsr(IntPtr dataset,
}

[DllImport(DllName, EntryPoint = "LGBM_DatasetFree", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetFree(IntPtr handle);
private static extern int DatasetFree(IntPtr handle);

[DllImport(DllName, EntryPoint = "LGBM_DatasetSetField", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetSetField(
IntPtr handle,
SafeDataSetHandle handle,
[MarshalAs(UnmanagedType.LPStr)]string field,
IntPtr array,
int len,
CApiDType type);

[DllImport(DllName, EntryPoint = "LGBM_DatasetGetNumData", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetGetNumData(IntPtr handle, ref int res);
public static extern int DatasetGetNumData(SafeDataSetHandle handle, ref int res);

[DllImport(DllName, EntryPoint = "LGBM_DatasetGetNumFeature", CallingConvention = CallingConvention.StdCall)]
public static extern int DatasetGetNumFeature(IntPtr handle, ref int res);
public static extern int DatasetGetNumFeature(SafeDataSetHandle handle, ref int res);

#endregion

#region API Booster

public sealed class SafeBoosterHandle : SafeHandleZeroOrMinusOneIsInvalid
{
private SafeBoosterHandle()
: base(true)
{
}

protected override bool ReleaseHandle()
{
LightGbmInterfaceUtils.Check(BoosterFree(handle));
return true;
}
}

[DllImport(DllName, EntryPoint = "LGBM_BoosterCreate", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterCreate(IntPtr trainset,
public static extern int BoosterCreate(SafeDataSetHandle trainset,
[MarshalAs(UnmanagedType.LPStr)]string param,
ref IntPtr res);
out SafeBoosterHandle res);

[DllImport(DllName, EntryPoint = "LGBM_BoosterFree", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterFree(IntPtr handle);
private static extern int BoosterFree(IntPtr handle);

[DllImport(DllName, EntryPoint = "LGBM_BoosterAddValidData", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterAddValidData(IntPtr handle, IntPtr validset);
public static extern int BoosterAddValidData(SafeBoosterHandle handle, SafeDataSetHandle validset);

[DllImport(DllName, EntryPoint = "LGBM_BoosterSaveModelToString", CallingConvention = CallingConvention.StdCall)]
public static extern unsafe int BoosterSaveModelToString(IntPtr handle,
public static extern unsafe int BoosterSaveModelToString(SafeBoosterHandle handle,
int startIteration,
int numIteration,
int bufferLen,
Expand All @@ -170,20 +199,20 @@ public static extern unsafe int BoosterSaveModelToString(IntPtr handle,
#region API train

[DllImport(DllName, EntryPoint = "LGBM_BoosterUpdateOneIter", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterUpdateOneIter(IntPtr handle, ref int isFinished);
public static extern int BoosterUpdateOneIter(SafeBoosterHandle handle, ref int isFinished);

[DllImport(DllName, EntryPoint = "LGBM_BoosterGetEvalCounts", CallingConvention = CallingConvention.StdCall)]
public static extern int BoosterGetEvalCounts(IntPtr handle, ref int outLen);
public static extern int BoosterGetEvalCounts(SafeBoosterHandle handle, ref int outLen);

[DllImport(DllName, EntryPoint = "LGBM_BoosterGetEval", CallingConvention = CallingConvention.StdCall)]
public static extern unsafe int BoosterGetEval(IntPtr handle, int dataIdx,
public static extern unsafe int BoosterGetEval(SafeBoosterHandle handle, int dataIdx,
ref int outLen, double* outResult);

#endregion

#region API predict
[DllImport(DllName, EntryPoint = "LGBM_BoosterPredictForMat", CallingConvention = CallingConvention.StdCall)]
public static extern unsafe int BoosterPredictForMat(IntPtr handle, IntPtr data, CApiDType dataType, int nRow, int nCol, int isRowMajor,
public static extern unsafe int BoosterPredictForMat(SafeBoosterHandle handle, IntPtr data, CApiDType dataType, int nRow, int nCol, int isRowMajor,
int predictType, int numIteration, [MarshalAs(UnmanagedType.LPStr)]string parameters, ref int outLen, double* outResult);
#endregion

Expand Down

0 comments on commit 2e0000f

Please sign in to comment.