Skip to content

Commit

Permalink
fixes one dal dispatching issues (#6547)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelgsharp authored Jan 24, 2023
1 parent a06dadc commit 9181467
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 45 deletions.
20 changes: 20 additions & 0 deletions src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Reflection;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

Expand Down Expand Up @@ -171,5 +172,24 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)

[BestFriend]
internal void CancelExecution() => ((ICancelable)_env).CancelExecution();

[BestFriend]
internal static readonly bool OneDalDispatchingEnabled = InitializeOneDalDispatchingEnabled();

private static bool InitializeOneDalDispatchingEnabled()
{
try
{
var asm = Assembly.Load("Microsoft.ML.OneDal");
var type = asm.GetType("Microsoft.ML.OneDal.OneDalUtils");
var method = type.GetMethod("IsDispatchingEnabled", BindingFlags.Public | BindingFlags.Static | BindingFlags.NonPublic);
var result = method.Invoke(null, null);
return (bool)result;
}
catch
{
return false;
}
}
}
}
16 changes: 1 addition & 15 deletions src/Microsoft.ML.FastTree/RandomForestClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ private protected override FastForestBinaryModelParameters TrainModelCore(TrainC
FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount();
ConvertData(trainData);

if (!trainData.Schema.Weight.HasValue && IsDispatchingToOneDalEnabled())
if (!trainData.Schema.Weight.HasValue && MLContext.OneDalDispatchingEnabled)
{
if (FastTreeTrainerOptions.FeatureFraction != 1.0)
{
Expand Down Expand Up @@ -262,20 +262,6 @@ public static extern unsafe int DecisionForestClassificationCompute(
void* lteChildPtr, void* gtChildPtr, void* splitFeaturePtr, void* featureThresholdPtr, void* leafValuesPtr, void* modelPtr);
}

[BestFriend]
private bool IsDispatchingToOneDalEnabled()
{
try
{
return OneDalUtils.IsDispatchingEnabled();
}
catch (Exception)
{
// Bail to default implementation upon encountering any situation where dispatch failed
return false;
}
}

[BestFriend]
private void TrainCoreOneDal(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
{
Expand Down
16 changes: 1 addition & 15 deletions src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ private protected override FastForestRegressionModelParameters TrainModelCore(Tr
FeatureCount = trainData.Schema.Feature.Value.Type.GetValueCount();
ConvertData(trainData);

if (!trainData.Schema.Weight.HasValue && IsDispatchingToOneDalEnabled())
if (!trainData.Schema.Weight.HasValue && MLContext.OneDalDispatchingEnabled)
{
if (FastTreeTrainerOptions.FeatureFraction != 1.0)
{
Expand Down Expand Up @@ -395,20 +395,6 @@ public static extern unsafe int DecisionForestRegressionCompute(
void* lteChildPtr, void* gtChildPtr, void* splitFeaturePtr, void* featureThresholdPtr, void* leafValuesPtr, void* modelPtr);
}

[BestFriend]
private bool IsDispatchingToOneDalEnabled()
{
try
{
return OneDalUtils.IsDispatchingEnabled();
}
catch (Exception)
{
// fall back to original implementation for any circumstance that prevents dispatching
return false;
}
}

[BestFriend]
private void TrainCoreOneDal(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
{
Expand Down
16 changes: 1 addition & 15 deletions src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -406,20 +406,6 @@ private void ComputeMklRegression(IChannel ch, FloatLabelCursor.Factory cursorFa
xty = null;
}

[BestFriend]
private bool IsDispatchingToOneDalEnabled()
{
try
{
return OneDalUtils.IsDispatchingEnabled();
}
catch (Exception)
{
// Bail to default implementation upon any situation that prevents dispatching
return false;
}
}

private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)
{
Host.AssertValue(ch);
Expand All @@ -440,7 +426,7 @@ private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory curso
var beta = new Double[m];
Double yMean = 0;

if (IsDispatchingToOneDalEnabled())
if (MLContext.OneDalDispatchingEnabled)
{
ComputeOneDalRegression(ch, cursorFactory, m, ref beta, xtx, ref n, ref yMean);
}
Expand Down

0 comments on commit 9181467

Please sign in to comment.