Skip to content

Commit dbbc69e

Browse files
authored
Bring ensembles into codebase (#379)
Introduce Ensemble codebase
1 parent 17f944c commit dbbc69e

File tree

73 files changed

+19405
-10044
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

73 files changed

+19405
-10044
lines changed

Microsoft.ML.sln

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Sweeper.Tests"
118118
EndProject
119119
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.LightGBM", "src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj", "{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}"
120120
EndProject
121+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Ensemble", "src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj", "{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}"
122+
EndProject
121123
Global
122124
GlobalSection(SolutionConfigurationPlatforms) = preSolution
123125
Debug|Any CPU = Debug|Any CPU
@@ -228,6 +230,10 @@ Global
228230
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Debug|Any CPU.Build.0 = Debug|Any CPU
229231
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Release|Any CPU.ActiveCfg = Release|Any CPU
230232
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}.Release|Any CPU.Build.0 = Release|Any CPU
233+
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
234+
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Debug|Any CPU.Build.0 = Debug|Any CPU
235+
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Release|Any CPU.ActiveCfg = Release|Any CPU
236+
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}.Release|Any CPU.Build.0 = Release|Any CPU
231237
EndGlobalSection
232238
GlobalSection(SolutionProperties) = preSolution
233239
HideSolutionNode = FALSE
@@ -267,6 +273,7 @@ Global
267273
{9252A8EB-ABFB-440C-AB4D-1D562753CE0F} = {487213C9-E8A9-4F94-85D7-28A05DBBFE3A}
268274
{3DEB504D-7A07-48CE-91A2-8047461CB3D4} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
269275
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
276+
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
270277
EndGlobalSection
271278
GlobalSection(ExtensibilityGlobals) = postSolution
272279
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}

src/Microsoft.ML.Ensemble/Batch.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Data;
6+
7+
namespace Microsoft.ML.Runtime.Ensemble
8+
{
9+
/// <summary>
/// Pairs two <see cref="RoleMappedData"/> instances, exposing one as
/// <see cref="TrainInstances"/> and the other as <see cref="TestInstances"/>.
/// Both fields are read-only and are assigned exactly once, in the constructor.
/// </summary>
public sealed class Batch
{
    /// <summary>The data passed to the constructor as <c>trainData</c>.</summary>
    public readonly RoleMappedData TrainInstances;

    /// <summary>The data passed to the constructor as <c>testData</c>.</summary>
    public readonly RoleMappedData TestInstances;

    /// <summary>
    /// Creates a batch from the two given data sets.
    /// </summary>
    /// <param name="trainData">Stored in <see cref="TrainInstances"/>; checked with <c>Contracts.CheckValue</c>.</param>
    /// <param name="testData">Stored in <see cref="TestInstances"/>; checked with <c>Contracts.CheckValue</c>.</param>
    public Batch(RoleMappedData trainData, RoleMappedData testData)
    {
        // Validate both arguments, in parameter order, before storing either of them.
        Contracts.CheckValue(trainData, nameof(trainData));
        Contracts.CheckValue(testData, nameof(testData));

        TestInstances = testData;
        TrainInstances = trainData;
    }
}
22+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections;
7+
using Microsoft.ML.Runtime.Data;
8+
using Microsoft.ML.Runtime.Internal.Utilities;
9+
10+
namespace Microsoft.ML.Runtime.Ensemble
11+
{
12+
/// <summary>
/// Helpers for masking feature vectors with a <see cref="BitArray"/> of selected slots.
/// </summary>
internal static class EnsembleUtils
{
    /// <summary>
    /// Produces a data set whose feature column has every non-selected slot zeroed out.
    /// </summary>
    /// <param name="host">Host handed to <c>LambdaColumnMapper.Create</c>.</param>
    /// <param name="data">Source data; its schema must have a feature column.</param>
    /// <param name="features">One bit per feature slot; a set bit keeps the slot.</param>
    /// <returns>
    /// <paramref name="data"/> itself when every slot is selected; otherwise a new
    /// <see cref="RoleMappedData"/> built over a mapped view, with the same column role names.
    /// </returns>
    public static RoleMappedData SelectFeatures(IHost host, RoleMappedData data, BitArray features)
    {
        Contracts.AssertValue(host);
        Contracts.AssertValue(data);
        Contracts.Assert(data.Schema.Feature != null);
        Contracts.AssertValue(features);

        var featureType = data.Schema.Feature.Type;
        Contracts.Assert(features.Length == featureType.VectorSize);

        // Nothing to mask when all slots are selected: hand the input straight back.
        int selectedCount = Utils.GetCardinality(features);
        if (selectedCount == featureType.VectorSize)
            return data;

        // REVIEW: Metadata on the features column is not carried over to the mapped column. Should it be?
        var featureName = data.Schema.Feature.Name;
        var maskedView = LambdaColumnMapper.Create(
            host, "FeatureSelector", data.Data, featureName, featureName, featureType, featureType,
            (ref VBuffer<float> input, ref VBuffer<float> output) =>
                SelectFeatures(ref input, features, selectedCount, ref output));

        return RoleMappedData.Create(maskedView, data.Schema.GetColumnRoleNames());
    }

    /// <summary>
    /// Copies into <paramref name="dst"/> the values of <paramref name="src"/> whose slot index is set in
    /// <paramref name="includedIndices"/>; every other slot takes default(T). The logical length of
    /// <paramref name="dst"/> equals <c>src.Length</c>.
    /// </summary>
    /// <remarks>
    /// The result is dense only when <paramref name="src"/> is dense and at least <c>src.Length / 2</c>
    /// slots are selected; in every other case it is sparse. The arrays already held by
    /// <paramref name="dst"/> are reused unless <c>Utils.Size</c> reports them as too small.
    /// </remarks>
    /// <param name="src">Source vector.</param>
    /// <param name="includedIndices">Selection mask, asserted to have one entry per slot of <paramref name="src"/>.</param>
    /// <param name="cardinality">Number of set bits in <paramref name="includedIndices"/>; asserted to be less than <c>src.Length</c>.</param>
    /// <param name="dst">Destination vector; replaced by the final statement of whichever path runs.</param>
    public static void SelectFeatures<T>(ref VBuffer<T> src, BitArray includedIndices, int cardinality, ref VBuffer<T> dst)
    {
        Contracts.Assert(Utils.Size(includedIndices) == src.Length);
        Contracts.Assert(cardinality == Utils.GetCardinality(includedIndices));
        Contracts.Assert(cardinality < src.Length);

        var outValues = dst.Values;
        var outIndices = dst.Indices;
        int length = src.Length;

        if (!src.IsDense)
        {
            // Sparse input -> sparse output: keep only the explicit entries whose slot is selected.
            int valueCapacity = Utils.Size(outValues);
            int indexCapacity = Utils.Size(outIndices);
            bool mayNotFit = valueCapacity < src.Count || indexCapacity < src.Count;
            if (mayNotFit && valueCapacity < cardinality)
                outValues = new T[cardinality];
            if (mayNotFit && indexCapacity < cardinality)
                outIndices = new int[cardinality];

            int kept = 0;
            for (int entry = 0; entry < src.Count; entry++)
            {
                int slot = src.Indices[entry];
                if (!includedIndices[slot])
                    continue;

                outValues[kept] = src.Values[entry];
                outIndices[kept] = slot;
                kept++;
            }

            dst = new VBuffer<T>(length, kept, outValues, outIndices);
            return;
        }

        if (cardinality < length / 2)
        {
            // Dense input with fewer than half the slots selected -> sparse output.
            if (Utils.Size(outValues) < cardinality)
                outValues = new T[cardinality];
            if (Utils.Size(outIndices) < cardinality)
                outIndices = new int[cardinality];

            int kept = 0;
            for (int slot = 0; slot < length; slot++)
            {
                if (!includedIndices[slot])
                    continue;

                Contracts.Assert(kept < cardinality);
                outValues[kept] = src.Values[slot];
                outIndices[kept] = slot;
                kept++;
            }

            Contracts.Assert(kept == cardinality);
            dst = new VBuffer<T>(length, kept, outValues, outIndices);
            return;
        }

        // Dense input with at least half the slots selected -> dense output with the rest blanked.
        if (Utils.Size(outValues) < length)
            outValues = new T[length];

        for (int slot = 0; slot < length; slot++)
        {
            if (includedIndices[slot])
                outValues[slot] = src.Values[slot];
            else
                outValues[slot] = default(T);
        }

        dst = new VBuffer<T>(length, outValues, outIndices);
    }
}
114+
}

0 commit comments

Comments
 (0)