From 125b6d5d3fdf8cad8d401bf3bf55d591f31aaa56 Mon Sep 17 00:00:00 2001 From: zewditu Hailemariam <36615490+zewditu@users.noreply.github.com> Date: Fri, 19 Jan 2024 11:36:36 -0800 Subject: [PATCH] Add sweepable estimator to NER (#6965) --- .../CodeGen/estimator-schema.json | 3 +- ...named_entity_recognition_search_space.json | 36 +++++++++++++++++++ .../CodeGen/search-space-schema.json | 6 ++-- .../CodeGen/trainer-estimators.json | 7 ++++ .../Estimators/NamedEntityRecognitionMulti.cs | 26 ++++++++++++++ 5 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 src/Microsoft.ML.AutoML/CodeGen/named_entity_recognition_search_space.json create mode 100644 src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/NamedEntityRecognitionMulti.cs diff --git a/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json b/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json index 8c12390426..048d707214 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json +++ b/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json @@ -74,7 +74,8 @@ "TextClassifcation", "SentenceSimilarity", "ObjectDetection", - "QuestionAnswering" + "QuestionAnswering", + "NamedEntityRecognition" ] }, "nugetDependencies": { diff --git a/src/Microsoft.ML.AutoML/CodeGen/named_entity_recognition_search_space.json b/src/Microsoft.ML.AutoML/CodeGen/named_entity_recognition_search_space.json new file mode 100644 index 0000000000..bd5a66036d --- /dev/null +++ b/src/Microsoft.ML.AutoML/CodeGen/named_entity_recognition_search_space.json @@ -0,0 +1,36 @@ +{ + "$schema": "./search-space-schema.json#", + "name": "named_entity_recognition_option", + "search_space": [ + { + "name": "PredictionColumnName", + "type": "string", + "default": "predictedLabel" + }, + { + "name": "LabelColumnName", + "type": "string", + "default": "Label" + }, + { + "name": "Sentence1ColumnName", + "type": "string", + "default": "Sentence" + }, + { + "name": "BatchSize", + "type": "integer", + "default": 32 + }, + { + "name": "MaxEpochs", + "type": "integer", + "default": 10 + }, + { + "name": "Architecture", + "type": "bertArchitecture", + "default": "BertArchitecture.Roberta" + } + ] +} diff --git a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json index 0ccb7b1fcf..b780854c87 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json +++ b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json @@ -167,7 +167,8 @@ "text_classification_option", "sentence_similarity_option", "object_detection_option", - "question_answering_option" + "question_answering_option", + "named_entity_recognition_option" ] }, "option_name": { @@ -238,7 +239,8 @@ "AnswerIndexStartColumnName", "predictedAnswerColumnName", "TopKAnswers", - "TargetType" + "TargetType", + "PredictionColumnName" ] }, "option_type": { diff --git a/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json b/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json index 0ce5a45e37..e0df321f38 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json +++ b/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json @@ -539,6 +539,13 @@ "usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ], "searchOption": "question_answering_option" }, + { + "functionName": "NamedEntityRecognition", + "estimatorTypes": [ "MultiClassification" ], + "nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ], + "usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ], + "searchOption": "named_entity_recognition_option" + }, { "functionName": "ForecastBySsa", "estimatorTypes": [ "Forecasting" ], diff --git a/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/NamedEntityRecognitionMulti.cs b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/NamedEntityRecognitionMulti.cs new file mode 100644 index 0000000000..8913f93d83 --- /dev/null +++ b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/NamedEntityRecognitionMulti.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.ML.TorchSharp; +using Microsoft.ML.TorchSharp.NasBert; + +namespace Microsoft.ML.AutoML.CodeGen +{ + internal partial class NamedEntityRecognitionMulti + { + public override IEstimator BuildFromOption(MLContext context, NamedEntityRecognitionOption param) + { + return context.MulticlassClassification.Trainers.NamedEntityRecognition( + labelColumnName: param.LabelColumnName, + outputColumnName: param.PredictionColumnName, + sentence1ColumnName: param.Sentence1ColumnName, + batchSize: param.BatchSize, + maxEpochs: param.MaxEpochs, + architecture: BertArchitecture.Roberta); + } + } +}