Skip to content

Commit

Permalink
Return distinct array of ParameterSet when ProposeSweep is called (do…
Browse files Browse the repository at this point in the history
…tnet#368)

* Changed List to HashSet to ensure that there are no duplicates
  • Loading branch information
ross-p-smith authored and eerhardt committed Jul 27, 2018
1 parent 89953cd commit c4a03af
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 4 deletions.
7 changes: 7 additions & 0 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "netstandard2.0", "netstanda
pkg\Microsoft.ML\build\netstandard2.0\Microsoft.ML.targets = pkg\Microsoft.ML\build\netstandard2.0\Microsoft.ML.targets
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Sweeper.Tests", "test\Microsoft.ML.Sweeper.Tests\Microsoft.ML.Sweeper.Tests.csproj", "{3DEB504D-7A07-48CE-91A2-8047461CB3D4}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -216,6 +218,10 @@ Global
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.Build.0 = Debug|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.ActiveCfg = Release|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.Build.0 = Release|Any CPU
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -253,6 +259,7 @@ Global
{362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{487213C9-E8A9-4F94-85D7-28A05DBBFE3A} = {DEC8F776-49F7-4D87-836C-FE4DC057D08C}
{9252A8EB-ABFB-440C-AB4D-1D562753CE0F} = {487213C9-E8A9-4F94-85D7-28A05DBBFE3A}
{3DEB504D-7A07-48CE-91A2-8047461CB3D4} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.ML.Core/Prediction/ISweeper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ public override string ToString()
{
return string.Join(" ", _parameterValues.Select(kvp => string.Format("{0}={1}", kvp.Value.Name, kvp.Value.ValueText)).ToArray());
}

public override int GetHashCode()
{
return _hash;
}
}

/// <summary>
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Sweeper/Algorithms/Grid.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ protected SweeperBase(ArgumentsBase args, IHostEnvironment env, IValueGenerator[
SweepParameters = sweepParameters;
}

public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns)
public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
{
var prevParamSets = previousRuns?.Select(r => r.ParameterSet).ToList() ?? new List<ParameterSet>();
var result = new List<ParameterSet>();
var result = new HashSet<ParameterSet>();
for (int i = 0; i < maxSweeps; i++)
{
ParameterSet paramSet;
Expand Down Expand Up @@ -150,12 +150,12 @@ public RandomGridSweeper(IHostEnvironment env, Arguments args, IValueGenerator[]
}
}

public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns)
public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
{
if (_nGridPoints == 0)
return base.ProposeSweeps(maxSweeps, previousRuns);

var result = new List<ParameterSet>();
var result = new HashSet<ParameterSet>();
var prevParamSets = (previousRuns != null)
? previousRuns.Select(r => r.ParameterSet).ToList()
: new List<ParameterSet>();
Expand Down
12 changes: 12 additions & 0 deletions test/Microsoft.ML.Sweeper.Tests/Microsoft.ML.Sweeper.Tests.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netcoreapp2.0</TargetFramework>
<DefineConstants>CORECLR</DefineConstants>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
</ItemGroup>
</Project>
69 changes: 69 additions & 0 deletions test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Runtime.Sweeper;
using System;
using System.IO;
using Xunit;

namespace Microsoft.ML.Sweeper.Tests
{
public class SweeperTest
{
[Fact]
public void UniformRandomSweeperReturnsDistinctValuesWhenProposeSweep()
{
DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator();

using (var writer = new StreamWriter(new MemoryStream()))
using (var env = new TlcEnvironment(42, outWriter: writer, errWriter: writer))
{
var sweeper = new UniformRandomSweeper(env,
new SweeperBase.ArgumentsBase(),
new[] { valueGenerator });

var results = sweeper.ProposeSweeps(3);
Assert.NotNull(results);

int length = results.Length;
Assert.Equal(2, length);
}
}

[Fact]
public void RandomGridSweeperReturnsDistinctValuesWhenProposeSweep()
{
DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator();

using (var writer = new StreamWriter(new MemoryStream()))
using (var env = new TlcEnvironment(42, outWriter: writer, errWriter: writer))
{
var sweeper = new RandomGridSweeper(env,
new RandomGridSweeper.Arguments(),
new[] { valueGenerator });

var results = sweeper.ProposeSweeps(3);
Assert.NotNull(results);

int length = results.Length;
Assert.Equal(2, length);
}
}

private static DiscreteValueGenerator CreateDiscreteValueGenerator()
{
var args = new DiscreteParamArguments()
{
Name = "TestParam",
Values = new string[] { "one", "two" }
};

return new DiscreteValueGenerator(args);
}
}
}

0 comments on commit c4a03af

Please sign in to comment.