Skip to content

Commit

Permalink
add SentenceSimilarity sweepable estimator in AutoML (#6445)
Browse files Browse the repository at this point in the history
  • Loading branch information
LittleLittleCloud authored Nov 30, 2022
1 parent 42788c4 commit b1cb564
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 14 deletions.
4 changes: 3 additions & 1 deletion src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@
"DnnFeaturizerImage",
"Naive",
"ForecastBySsa",
"TextClassification"
"TextClassification",
"SentenceSimilarity"
]
},
"nugetDependencies": {
Expand Down Expand Up @@ -109,6 +110,7 @@
"Microsoft.ML.Vision",
"Microsoft.ML.Transforms.Image",
"Microsoft.ML.Trainers.FastTree",
"Microsoft.ML.TorchSharp",
"Microsoft.ML.Trainers.LightGbm"
]
}
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@
"image_classification_option",
"matrix_factorization_option",
"dnn_featurizer_image_option",
"text_classification_option"
"text_classification_option",
"sentence_similarity_option"
]
},
"option_name": {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"$schema": "./search-space-schema.json#",
"name": "sentence_similarity_option",
"search_space": [
{
"name": "LabelColumnName",
"type": "string",
"default": "Label"
},
{
"name": "Sentence1ColumnName",
"type": "string",
"default": "Sentence1"
},
{
"name": "Sentence2ColumnName",
"type": "string"
},
{
"name": "ScoreColumnName",
"type": "string",
"default": "Score"
},
{
"name": "BatchSize",
"type": "integer",
"default": 32
},
{
"name": "MaxEpochs",
"type": "integer",
"default": 10
},
{
"name": "Architecture",
"type": "bertArchitecture",
"default": "BertArchitecture.Roberta"
}
]
}
19 changes: 8 additions & 11 deletions src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@
"argumentType": "boolean"
}
],
"nugetDependencies": ["Microsoft.ML"],
"nugetDependencies": [ "Microsoft.ML" ],
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers" ],
"searchOption": "lbfgs_option"
},
Expand Down Expand Up @@ -514,20 +514,17 @@
{
"functionName": "TextClassification",
"estimatorTypes": [ "MultiClassification" ],
"arguments": [
{
"argumentName": "labelColumnName",
"argumentType": "string"
},
{
"argumentName": "sentence1ColumnName",
"argumentType": "string"
}
],
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
"searchOption": "text_classification_option"
},
{
"functionName": "SentenceSimilarity",
"estimatorTypes": [ "Regression" ],
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
"searchOption": "sentence_similarity_option"
},
{
"functionName": "ForecastBySsa",
"estimatorTypes": [ "Forecasting" ],
Expand Down
1 change: 0 additions & 1 deletion src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
<AdditionalFiles Include="CodeGen\*search_space.json" />
<AdditionalFiles Include="CodeGen\code_gen_flag.json" />
<AdditionalFiles Include="CodeGen\*-estimators.json" />
<AdditionalFiles Include="CodeGen\code_gen_flag.json" />
</ItemGroup>

<Target DependsOnTargets="ResolveReferences" Name="CopyProjectReferencesToPackage">
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Reflection;
using System.Text;
using Microsoft.ML.TorchSharp;

namespace Microsoft.ML.AutoML.CodeGen
{
/// <summary>
/// AutoML code-gen estimator that builds its trainer from a
/// <see cref="SentenceSimilarityOption"/>.
/// </summary>
internal partial class SentenceSimilarityRegression
{
    /// <summary>
    /// Creates the estimator by forwarding every option value, as a named argument,
    /// to <c>context.Regression.Trainers.SentenceSimilarity</c>.
    /// </summary>
    /// <param name="context">ML context whose regression trainer catalog is used.</param>
    /// <param name="param">Option values handed through to the trainer unchanged.</param>
    /// <returns>The estimator produced by the <c>SentenceSimilarity</c> call.</returns>
    public override IEstimator<ITransformer> BuildFromOption(MLContext context, SentenceSimilarityOption param)
    {
        var regressionTrainers = context.Regression.Trainers;

        // Arguments are passed by name, so their order here does not need to
        // match the parameter order of the catalog method.
        var estimator = regressionTrainers.SentenceSimilarity(
            labelColumnName: param.LabelColumnName,
            sentence1ColumnName: param.Sentence1ColumnName,
            scoreColumnName: param.ScoreColumnName,
            sentence2ColumnName: param.Sentence2ColumnName,
            batchSize: param.BatchSize,
            maxEpochs: param.MaxEpochs,
            architecture: param.Architecture);

        return estimator;
    }
}
}

0 comments on commit b1cb564

Please sign in to comment.