Skip to content

Commit 4d88ca8

Browse files
committed
Changed List to HashSet to ensure that there are no duplicates
1 parent c49f9f9 commit 4d88ca8

File tree

3 files changed

+49
-11
lines changed

3 files changed

+49
-11
lines changed

src/Microsoft.ML.Core/Prediction/ISweeper.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ public override string ToString()
174174
{
175175
return string.Join(" ", _parameterValues.Select(kvp => string.Format("{0}={1}", kvp.Value.Name, kvp.Value.ValueText)).ToArray());
176176
}
177+
178+
public override int GetHashCode()
179+
{
180+
return _hash;
181+
}
177182
}
178183

179184
/// <summary>

src/Microsoft.ML.Sweeper/Algorithms/Grid.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ protected SweeperBase(ArgumentsBase args, IHostEnvironment env, IValueGenerator[
6767
public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
6868
{
6969
var prevParamSets = previousRuns?.Select(r => r.ParameterSet).ToList() ?? new List<ParameterSet>();
70-
var result = new List<ParameterSet>();
70+
var result = new HashSet<ParameterSet>();
7171
for (int i = 0; i < maxSweeps; i++)
7272
{
7373
ParameterSet paramSet;
@@ -150,12 +150,12 @@ public RandomGridSweeper(IHostEnvironment env, Arguments args, IValueGenerator[]
150150
}
151151
}
152152

153-
public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns)
153+
public override ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable<IRunResult> previousRuns = null)
154154
{
155155
if (_nGridPoints == 0)
156156
return base.ProposeSweeps(maxSweeps, previousRuns);
157157

158-
var result = new List<ParameterSet>();
158+
var result = new HashSet<ParameterSet>();
159159
var prevParamSets = (previousRuns != null)
160160
? previousRuns.Select(r => r.ParameterSet).ToList()
161161
: new List<ParameterSet>();

test/Microsoft.ML.Sweeper.Tests/SweeperTest.cs

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,54 @@ namespace Microsoft.ML.Sweeper.Tests
1212
public class SweeperTest
1313
{
1414
[Fact]
15-
public void SweeperReturnsDistinctValues()
15+
public void UniformRandomSweeperReturnsDistinctValuesWhenProposeSweep()
1616
{
17-
var args = new DiscreteParamArguments();
18-
args.Name = "Amazing";
19-
args.Values = new string[] { "one" };
20-
var valueGenerator = new DiscreteValueGenerator(args);
17+
DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator();
18+
2119
using (var writer = new StreamWriter(new MemoryStream()))
2220
using (var env = new TlcEnvironment(42, outWriter: writer, errWriter: writer))
2321
{
24-
var sweeper = new UniformRandomSweeper(env, new SweeperBase.ArgumentsBase(), new[] { valueGenerator });
25-
var results = sweeper.ProposeSweeps(2);
22+
var sweeper = new UniformRandomSweeper(env,
23+
new SweeperBase.ArgumentsBase(),
24+
new[] { valueGenerator });
25+
26+
var results = sweeper.ProposeSweeps(5000);
2627
Assert.NotNull(results);
28+
2729
int length = results.Length;
28-
Assert.Equal(1, length);
30+
Assert.Equal(2, length);
2931
}
3032
}
33+
34+
[Fact]
35+
public void RandomGridSweeperReturnsDistinctValuesWhenProposeSweep()
36+
{
37+
DiscreteValueGenerator valueGenerator = CreateDiscreteValueGenerator();
38+
39+
using (var writer = new StreamWriter(new MemoryStream()))
40+
using (var env = new TlcEnvironment(42, outWriter: writer, errWriter: writer))
41+
{
42+
var sweeper = new RandomGridSweeper(env,
43+
new RandomGridSweeper.Arguments(),
44+
new[] { valueGenerator });
45+
46+
var results = sweeper.ProposeSweeps(5000);
47+
Assert.NotNull(results);
48+
49+
int length = results.Length;
50+
Assert.Equal(2, length);
51+
}
52+
}
53+
54+
private static DiscreteValueGenerator CreateDiscreteValueGenerator()
55+
{
56+
var args = new DiscreteParamArguments()
57+
{
58+
Name = "TestParam",
59+
Values = new string[] { "one", "two" }
60+
};
61+
62+
return new DiscreteValueGenerator(args);
63+
}
3164
}
3265
}

0 commit comments

Comments
 (0)