Skip to content

Commit

Permalink
Convert LdaEngine to a SafeHandle (#4538)
Browse files Browse the repository at this point in the history
  • Loading branch information
sharwell authored Dec 9, 2019
1 parent 30ebfc3 commit 8221dac
Showing 1 changed file with 34 additions and 25 deletions.
59 changes: 34 additions & 25 deletions src/Microsoft.ML.Transforms/Text/LdaSingleBox.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,83 +8,93 @@
using System.Runtime.InteropServices;
using System.Security;
using Microsoft.ML.Runtime;
using Microsoft.Win32.SafeHandles;

namespace Microsoft.ML.TextAnalytics
{

internal static class LdaInterface
{
public struct LdaEngine
public sealed class SafeLdaEngineHandle : SafeHandleZeroOrMinusOneIsInvalid
{
public IntPtr Ptr;
private SafeLdaEngineHandle()
: base(true)
{
}

protected override bool ReleaseHandle()
{
DestroyEngine(handle);
return true;
}
}

private const string NativePath = "LdaNative";
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern LdaEngine CreateEngine(int numTopic, int numVocab, float alphaSum, float beta, int numIter,
internal static extern SafeLdaEngineHandle CreateEngine(int numTopic, int numVocab, float alphaSum, float beta, int numIter,
int likelihoodInterval, int numThread, int mhstep, int maxDocToken);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void AllocateModelMemory(LdaEngine engine, int numTopic, int numVocab, long tableSize, long aliasTableSize);
internal static extern void AllocateModelMemory(SafeLdaEngineHandle engine, int numTopic, int numVocab, long tableSize, long aliasTableSize);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void AllocateDataMemory(LdaEngine engine, int docNum, long corpusSize);
internal static extern void AllocateDataMemory(SafeLdaEngineHandle engine, int docNum, long corpusSize);

[DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
internal static extern void Train(LdaEngine engine, string trainOutput);
internal static extern void Train(SafeLdaEngineHandle engine, string trainOutput);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void GetModelStat(LdaEngine engine, out long memBlockSize, out long aliasMemBlockSize);
internal static extern void GetModelStat(SafeLdaEngineHandle engine, out long memBlockSize, out long aliasMemBlockSize);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void Test(LdaEngine engine, int numBurninIter, float[] pLogLikelihood);
internal static extern void Test(SafeLdaEngineHandle engine, int numBurninIter, float[] pLogLikelihood);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void CleanData(LdaEngine engine);
internal static extern void CleanData(SafeLdaEngineHandle engine);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void CleanModel(LdaEngine engine);
internal static extern void CleanModel(SafeLdaEngineHandle engine);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void DestroyEngine(LdaEngine engine);
private static extern void DestroyEngine(IntPtr engine);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void GetWordTopic(LdaEngine engine, int wordId, int[] pTopic, int[] pProb, ref int length);
internal static extern void GetWordTopic(SafeLdaEngineHandle engine, int wordId, int[] pTopic, int[] pProb, ref int length);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void SetWordTopic(LdaEngine engine, int wordId, int[] pTopic, int[] pProb, int length);
internal static extern void SetWordTopic(SafeLdaEngineHandle engine, int wordId, int[] pTopic, int[] pProb, int length);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void SetAlphaSum(LdaEngine engine, float avgDocLength);
internal static extern void SetAlphaSum(SafeLdaEngineHandle engine, float avgDocLength);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern int FeedInData(LdaEngine engine, int[] termId, int[] termFreq, int termNum, int numVocab);
internal static extern int FeedInData(SafeLdaEngineHandle engine, int[] termId, int[] termFreq, int termNum, int numVocab);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern int FeedInDataDense(LdaEngine engine, int[] termFreq, int termNum, int numVocab);
internal static extern int FeedInDataDense(SafeLdaEngineHandle engine, int[] termFreq, int termNum, int numVocab);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void GetDocTopic(LdaEngine engine, int docId, int[] pTopic, int[] pProb, ref int numTopicReturn);
internal static extern void GetDocTopic(SafeLdaEngineHandle engine, int docId, int[] pTopic, int[] pProb, ref int numTopicReturn);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void GetTopicSummary(LdaEngine engine, int topicId, int[] pWords, float[] pProb, ref int numTopicReturn);
internal static extern void GetTopicSummary(SafeLdaEngineHandle engine, int topicId, int[] pWords, float[] pProb, ref int numTopicReturn);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void TestOneDoc(LdaEngine engine, int[] termId, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurnIter, bool reset);
internal static extern void TestOneDoc(SafeLdaEngineHandle engine, int[] termId, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurnIter, bool reset);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void TestOneDocDense(LdaEngine engine, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurninIter, bool reset);
internal static extern void TestOneDocDense(SafeLdaEngineHandle engine, int[] termFreq, int termNum, int[] pTopics, int[] pProbs, ref int numTopicsMax, int numBurninIter, bool reset);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void InitializeBeforeTrain(LdaEngine engine);
internal static extern void InitializeBeforeTrain(SafeLdaEngineHandle engine);

[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
internal static extern void InitializeBeforeTest(LdaEngine engine);
internal static extern void InitializeBeforeTest(SafeLdaEngineHandle engine);
}

internal sealed class LdaSingleBox : IDisposable
{
private LdaInterface.LdaEngine _engine;
private LdaInterface.SafeLdaEngineHandle _engine;
private bool _isDisposed;
private int[] _topics;
private int[] _probabilities;
Expand Down Expand Up @@ -358,8 +368,7 @@ public void Dispose()
if (_isDisposed)
return;
_isDisposed = true;
LdaInterface.DestroyEngine(_engine);
_engine.Ptr = IntPtr.Zero;
_engine.Dispose();
}
}
}

0 comments on commit 8221dac

Please sign in to comment.