Skip to content

Propose sweep #365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Mi
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Console\Microsoft.ML.Console.csproj", "{362A98CF-FBF7-4EBB-A11B-990BBF845B15}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Sweeper.Tests", "test\Microsoft.ML.Sweeper.Tests\Microsoft.ML.Sweeper.Tests.csproj", "{3DEB504D-7A07-48CE-91A2-8047461CB3D4}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -208,6 +210,10 @@ Global
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.Build.0 = Debug|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.ActiveCfg = Release|Any CPU
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.Build.0 = Release|Any CPU
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Debug|Any CPU.Build.0 = Debug|Any CPU
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3DEB504D-7A07-48CE-91A2-8047461CB3D4}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -243,6 +249,7 @@ Global
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{3DEB504D-7A07-48CE-91A2-8047461CB3D4} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
12 changes: 6 additions & 6 deletions src/Microsoft.ML.Sweeper/Algorithms/Grid.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ protected SweeperBase(ArgumentsBase args, IHostEnvironment env, IValueGenerator[
SweepParameters = sweepParameters;
}

public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns)
public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
{
var prevParamSets = previousRuns?.Select(r => r.ParameterSet).ToList() ?? new List<ParameterSet>();
var result = new List<ParameterSet>();
Expand All @@ -80,12 +80,11 @@ public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResul
(AlreadyGenerated(paramSet, prevParamSets) || AlreadyGenerated(paramSet, result)));

Contracts.Assert(paramSet != null);
result.Add(paramSet);
if (!result.Contains(paramSet))
result.Add(paramSet);
}

return result.ToArray();
}

protected abstract ParameterSet CreateParamSet();

protected static bool AlreadyGenerated(ParameterSet paramSet, IEnumerable<ParameterSet> previousRuns)
Expand Down Expand Up @@ -150,7 +149,7 @@ public RandomGridSweeper(IHostEnvironment env, Arguments args, IValueGenerator[]
}
}

public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns)
public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
{
if (_nGridPoints == 0)
return base.ProposeSweeps(maxSweeps, previousRuns);
Expand All @@ -173,7 +172,8 @@ public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResu
if (!AlreadyGenerated(_cache[iPerm], prevParamSets))
break;
}
result.Add(_cache[iPerm]);
if (!result.Contains(_cache[iPerm]))
result.Add(_cache[iPerm]);
}
return result.ToArray();
}
Expand Down
22 changes: 22 additions & 0 deletions test/Microsoft.ML.Sweeper.Tests/Microsoft.ML.Sweeper.Tests.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netcoreapp2.0</TargetFramework>

<IsPackable>false</IsPackable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.7.0" />
<PackageReference Include="xunit" Version="2.3.1" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.3.1" />
<DotNetCliToolReference Include="dotnet-xunit" Version="2.3.1" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
</ItemGroup>

</Project>
53 changes: 53 additions & 0 deletions test/Microsoft.ML.Sweeper.Tests/SweeperUniqueValuesTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.PipelineInference;
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Runtime.Sweeper;
using System;
using System.IO;
using Xunit;

namespace Microsoft.ML.Sweeper.Tests
{
public class SweeperUniqueValuesTest
{
[Fact]
public void SweeperReturnsDistinctValuesForUniformRandomSweeper()
{
DiscreteValueGenerator valueGenerator = SetSingleParameter();
using (var writer = new StreamWriter(new MemoryStream()))
using (var env = new TlcEnvironment(42, outWriter: writer, errWriter: writer))
{
var sweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase(), new[] { valueGenerator });
var results = sweeper.ProposeSweeps(5000);
Assert.NotNull(results);
int length = results.Length;
Assert.Equal(2, length);
}
}
[Fact]
public void SweeperReturnsDistinctValuesForRandomGridSweeper()
{
DiscreteValueGenerator valueGenerator = SetSingleParameter();
using (var writer = new StreamWriter(new MemoryStream()))
using (var env = new TlcEnvironment(42, outWriter: writer, errWriter: writer))
{
var sweeper = new RandomGridSweeper(env, new RandomGridSweeper.Arguments(), new[] { valueGenerator });
var results = sweeper.ProposeSweeps(5000);
Assert.NotNull(results);
int length = results.Length;
Assert.Equal(2, length);
}
}

private static DiscreteValueGenerator SetSingleParameter()
{
var args = new DiscreteParamArguments();
args.Name = "TestParam";
args.Values = new string[] { "one", "two" };
var valueGenerator = new DiscreteValueGenerator(args);
return valueGenerator;
}
}
}