Skip to content

Add NameEntityRecognition and Q&A deep learning tasks. #6760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build/ci/job-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
steps:
# Extra MacOS step required to install OS-specific dependencies
- ${{ if and(contains(parameters.pool.vmImage, 'macOS'), not(contains(parameters.name, 'cross'))) }}:
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=TRUE && brew update && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=TRUE && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
displayName: Install MacOS build dependencies
# Extra Apple MacOS step required to install OS-specific dependencies
- ${{ if and(contains(parameters.pool.vmImage, 'macOS'), contains(parameters.name, 'cross')) }}:
Expand Down
13 changes: 13 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
return null;
}

/// <summary>
/// Map the tokenized Id to a string. This model does not implement the mapping.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
/// <returns>Never returns; the method always throws.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public override string? IdToString(int id, bool skipSpecialTokens = false)
{
    throw new NotImplementedException();
}

/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
Expand Down Expand Up @@ -443,6 +451,11 @@ internal List<Token> TokenizeWithCache(string sequence)
return tokens;
}

/// <summary>
/// Character validation is not implemented for this model.
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns>Never returns; the method always throws.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public override bool IsValidChar(char ch) => throw new NotImplementedException();

internal static readonly List<Token> EmptyTokensList = new();
}
}
27 changes: 27 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,28 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
public override string? IdToToken(int id, bool skipSpecialTokens = false) =>
skipSpecialTokens && id < 0 ? null : _vocabReverse.TryGetValue(id, out var value) ? value : null;

/// <summary>
/// Map the tokenized Id to the original string.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
/// <returns>
/// The vocabulary token for the Id with every character translated through the unicode-to-byte table
/// (characters that have no entry in the table are dropped), or null when the Id is skipped as a
/// special token or is not present in the vocabulary.
/// </returns>
public override string? IdToString(int id, bool skipSpecialTokens = false)
{
    // Same special-token test that IdToToken uses: negative Ids are skipped on request.
    if (skipSpecialTokens && id < 0)
    {
        return null;
    }

    if (!_vocabReverse.TryGetValue(id, out var value))
    {
        return null;
    }

    // The result can never be longer than the token itself, so a single buffer of that size suffices.
    var buffer = new char[value.Length];
    int length = 0;
    foreach (char c in value)
    {
        // One lookup per character (TryGetValue) instead of ContainsKey followed by the indexer.
        if (_unicodeToByte.TryGetValue(c, out var mapped))
        {
            buffer[length++] = mapped;
        }
    }

    return new string(buffer, 0, length);
}

/// <summary>
/// Save the model data into the vocabulary, merges, and occurrence mapping files.
/// </summary>
Expand Down Expand Up @@ -565,6 +587,11 @@ private List<Token> BpeToken(Span<char> token, Span<int> indexMapping)

return pairs;
}

/// <summary>
/// Return true if the char has an entry in the byte-to-unicode table; otherwise return false.
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns>True when <paramref name="ch"/> is a key of the byte-to-unicode table.</returns>
public override bool IsValidChar(char ch) => _byteToUnicode.ContainsKey(ch);
}

/// <summary>
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public abstract class Model
/// <returns>The mapped token of the Id.</returns>
public abstract string? IdToToken(int id, bool skipSpecialTokens = false);

/// <summary>
/// Map the tokenized Id to a string.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicate whether to skip the special tokens during the decoding.</param>
/// <returns>The string mapped from the Id; the return type is nullable, so implementations may return null.</returns>
public abstract string? IdToString(int id, bool skipSpecialTokens = false);

/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
Expand All @@ -57,6 +59,14 @@ public abstract class Model
/// Gets a trainer object to use in training the model.
/// </summary>
public abstract Trainer? GetTrainer();

/// <summary>
/// Return true if the char is valid in the tokenizer; otherwise return false.
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns>True if <paramref name="ch"/> is valid in the tokenizer; otherwise false.</returns>
public abstract bool IsValidChar(char ch);

}

}
10 changes: 9 additions & 1 deletion src/Microsoft.ML.Tokenizers/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ public TokenizerResult Encode(string sequence)

foreach (int id in ids)
{
tokens.Add(Model.IdToToken(id) ?? "");
if (Model.GetType() == typeof(EnglishRoberta))
tokens.Add(Model.IdToString(id) ?? "");
else
tokens.Add(Model.IdToToken(id) ?? "");
}

return Decoder?.Decode(tokens) ?? string.Join("", tokens);
Expand Down Expand Up @@ -187,5 +190,10 @@ public void TrainFromFiles(
// To Do: support added vocabulary in the tokenizer which will include this returned special_tokens.
// self.add_special_tokens(&special_tokens);
}

/// <summary>
/// Return true if the char is valid for the tokenizer's model; otherwise return false.
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns>The result of the model's own IsValidChar check for <paramref name="ch"/>.</returns>
public bool IsValidChar(char ch) => Model.IsValidChar(ch);
}
}
36 changes: 28 additions & 8 deletions src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,12 @@
using static TorchSharp.torch.optim.lr_scheduler;
using Microsoft.ML.TorchSharp.Utils;
using Microsoft.ML;
using Microsoft.ML.TorchSharp.NasBert;
using System.IO;
using Microsoft.ML.Data.IO;
using Microsoft.ML.TorchSharp.Loss;
using Microsoft.ML.Transforms.Image;
using static Microsoft.ML.TorchSharp.AutoFormerV2.ObjectDetectionTrainer;
using Microsoft.ML.TorchSharp.AutoFormerV2;
using Microsoft.ML.Tokenizers;
using Microsoft.ML.TorchSharp.Extensions;
using Microsoft.ML.TorchSharp.NasBert.Models;
using static Microsoft.ML.TorchSharp.NasBert.NasBertTrainer;
using TorchSharp.Modules;
using System.Text;
using static Microsoft.ML.Data.AnnotationUtils;

[assembly: LoadableClass(typeof(ObjectDetectionTransformer), null, typeof(SignatureLoadModel),
Expand Down Expand Up @@ -503,7 +496,7 @@ private void CheckInputSchema(SchemaShape inputSchema)
}
}

public class ObjectDetectionTransformer : RowToRowTransformerBase
public class ObjectDetectionTransformer : RowToRowTransformerBase, IDisposable
{
private protected readonly Device Device;
private protected readonly AutoFormerV2 Model;
Expand All @@ -522,6 +515,7 @@ public class ObjectDetectionTransformer : RowToRowTransformerBase

private static readonly FuncStaticMethodInfo1<object, Delegate> _decodeInitMethodInfo
= new FuncStaticMethodInfo1<object, Delegate>(DecodeInit<int>);
private bool _disposedValue;

internal ObjectDetectionTransformer(IHostEnvironment env, ObjectDetectionTrainer.Options options, AutoFormerV2 model, DataViewSchema.DetachedColumn labelColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTransformer)))
Expand Down Expand Up @@ -992,5 +986,31 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
return col => (activeOutput(0) || activeOutput(1) || activeOutput(2)) && _inputColIndices.Any(i => i == col);
}
}

/// <summary>
/// Releases the resources held by this transformer. Only the first call does any work.
/// </summary>
/// <param name="disposing">
/// True when called from <c>Dispose()</c>; false when called from the finalizer.
/// </param>
protected virtual void Dispose(bool disposing)
{
    if (_disposedValue)
    {
        return;
    }

    if (disposing)
    {
        // Model is a managed object, so it is only touched on the explicit Dispose() path:
        // a finalizer must not call into other managed objects, which may already have been
        // finalized. The null-conditional guards against a partially constructed instance.
        // NOTE(review): this relies on the model type releasing its own native state when it is
        // collected without an explicit Dispose() call - confirm.
        Model?.Dispose();
    }

    _disposedValue = true;
}

// Finalizer: delegates to Dispose(bool) with disposing set to false.
// Do not add cleanup code here; put it in the 'Dispose(bool disposing)' method.
~ObjectDetectionTransformer() => Dispose(disposing: false);

/// <summary>
/// Disposes this transformer and suppresses its finalizer.
/// </summary>
public void Dispose()
{
    // Cleanup logic belongs in 'Dispose(bool disposing)'; this method only drives it.
    Dispose(true);

    // Cleanup has already run, so the finalizer does not need to.
    GC.SuppressFinalize(this);
}
}
}
16 changes: 16 additions & 0 deletions src/Microsoft.ML.TorchSharp/NasBert/BertModelType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.TorchSharp.NasBert
{
/// <summary>
/// The kinds of BERT model that can be selected.
/// </summary>
internal enum BertModelType
{
    /// <summary>The NasBert model.</summary>
    NasBert = 0,

    /// <summary>The Roberta model.</summary>
    Roberta = 1
}
}
4 changes: 3 additions & 1 deletion src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public enum BertTaskType
None = 0,
MaskedLM = 1,
TextClassification = 2,
SentenceRegression = 3
SentenceRegression = 3,
NameEntityRecognition = 4,
QuestionAnswering = 5
}
}
3 changes: 0 additions & 3 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/BaseHead.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;
using TorchSharp;

namespace Microsoft.ML.TorchSharp.NasBert.Models
Expand Down
11 changes: 5 additions & 6 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/BaseModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.TorchSharp.Utils;
using TorchSharp;

namespace Microsoft.ML.TorchSharp.NasBert.Models
{
internal abstract class BaseModel : torch.nn.Module<torch.Tensor, torch.Tensor, torch.Tensor>
{
protected readonly NasBertTrainer.NasBertOptions Options;
public BertTaskType HeadType => Options.TaskType;
public BertModelType EncoderType => Options.ModelType;

//public ModelType EncoderType => Options.ModelType;
public BertTaskType HeadType => Options.TaskType;

#pragma warning disable CA1024 // Use properties where appropriate: Modules should be fields in TorchSharp
public abstract TransformerEncoder GetEncoder();

public abstract BaseHead GetHead();

#pragma warning restore CA1024 // Use properties where appropriate

protected BaseModel(NasBertTrainer.NasBertOptions options)
Expand Down
36 changes: 36 additions & 0 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/ModelPrediction.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using TorchSharp;

namespace Microsoft.ML.TorchSharp.NasBert.Models
{
internal sealed class ModelForPrediction : NasBertModel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NERInferenceModel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't for NER. It's for SentenceSimilarity and TextClassification. How about TextModel? TextModelForPrediction? Thoughts?

{
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "Has to match TorchSharp model.")]
private readonly PredictionHead PredictionHead;

/// <summary>
/// Gets the prediction head of this model.
/// </summary>
/// <returns>The <c>PredictionHead</c> instance created in the constructor.</returns>
public override BaseHead GetHead()
{
    return PredictionHead;
}

/// <summary>
/// Creates the model: forwards <paramref name="options"/>, <paramref name="padIndex"/> and
/// <paramref name="symbolsCount"/> to the base constructor, then builds the prediction head.
/// </summary>
/// <param name="options">The options passed to the base model.</param>
/// <param name="padIndex">Passed through to the base constructor.</param>
/// <param name="symbolsCount">Passed through to the base constructor.</param>
/// <param name="numClasses">The number of classes given to the prediction head.</param>
public ModelForPrediction(NasBertTrainer.NasBertOptions options, int padIndex, int symbolsCount, int numClasses)
: base(options, padIndex, symbolsCount)
{
// The head takes its input size and dropout rate from Options.
PredictionHead = new PredictionHead(
inputDim: Options.EncoderOutputDim,
numClasses: numClasses,
dropoutRate: Options.PoolerDropout);
// Both calls run after the head field is assigned.
// NOTE(review): presumably so the head is covered by initialization and registration - confirm.
Initialize();
RegisterComponents();
}

/// <summary>
/// Runs the model: extracts features from the source tokens and applies the prediction head.
/// </summary>
/// <param name="srcTokens">The source tokens passed to feature extraction.</param>
/// <param name="tokenMask">Not used by this implementation.</param>
/// <returns>The prediction head output, moved to the outer dispose scope.</returns>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor tokenMask = null)
{
    // Tensors created below belong to this scope; only the returned one is moved out of it.
    using var scope = torch.NewDisposeScope();

    var features = ExtractFeatures(srcTokens);
    var prediction = PredictionHead.call(features);

    return prediction.MoveToOuterDisposeScope();
}
}
}
Loading