Skip to content

Commit 2ed64b5

Browse files
george-microsofteerhardt
authored andcommitted
RocketEngine fix for selecting top learners (dotnet#270)
* Changes to RocketEngine to fix take top k logic. * Add namespace information to allow file to reference correct version of Formatting object.
1 parent aa2f5b4 commit 2ed64b5

File tree

2 files changed

+5
-16
lines changed

2 files changed

+5
-16
lines changed

src/Microsoft.ML.PipelineInference/AutoMlEngines/RocketEngine.cs

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -196,21 +196,10 @@ private TransformInference.SuggestedTransform[] SampleTransforms(RecipeInference
196196
private RecipeInference.SuggestedRecipe.SuggestedLearner[] GetTopLearners(IEnumerable<PipelinePattern> history)
197197
{
198198
var weights = LearnerHistoryToWeights(history.ToArray(), IsMaximizingMetric);
199-
var topKTuples = new Tuple<double, int>[_topK];
200-
201-
for (int i = 0; i < weights.Length; i++)
202-
{
203-
if (i < _topK)
204-
topKTuples[i] = new Tuple<double, int>(weights[i], i);
205-
else
206-
{
207-
for (int j = 0; j < topKTuples.Length; j++)
208-
if (weights[i] > topKTuples[j].Item1)
209-
topKTuples[j] = new Tuple<double, int>(weights[i], i);
210-
}
211-
}
212-
213-
return topKTuples.Select(t => AvailableLearners[t.Item2]).ToArray();
199+
return weights.Select((w, i) => new { Weight = w, Index = i })
200+
.OrderByDescending(x => x.Weight)
201+
.Take(_topK)
202+
.Select(t=>AvailableLearners[t.Index]).ToArray();
214203
}
215204

216205
public override PipelinePattern[] GetNextCandidates(IEnumerable<PipelinePattern> history, int numCandidates)

src/Microsoft.ML.PipelineInference/DatasetFeaturesInference.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ public static string InferDatasetFeatures(IHostEnvironment env, Arguments args)
507507
}
508508

509509
if (args.PrettyPrint)
510-
jsonString = JsonConvert.SerializeObject(features, Formatting.Indented);
510+
jsonString = JsonConvert.SerializeObject(features, Newtonsoft.Json.Formatting.Indented);
511511
else
512512
jsonString = JsonConvert.SerializeObject(features);
513513

0 commit comments

Comments
 (0)