Skip to content

Commit

Permalink
fixed mac build and minor torch sharp changes (#6776)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelgsharp committed Jul 28, 2023
1 parent 7b6af06 commit 8952994
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 10 deletions.
2 changes: 1 addition & 1 deletion build/vsts-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ stages:
pool:
vmImage: macOS-12
steps:
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && brew update && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
displayName: Install build dependencies
# Only build native assets to avoid conflicts.
- script: ./build.sh -projects $(Build.SourcesDirectory)/src/Native/Native.proj -configuration $(BuildConfig) /p:TargetArchitecture=x64 /p:CopyPackageAssets=true
Expand Down
16 changes: 14 additions & 2 deletions src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,19 @@ namespace Microsoft.ML.TorchSharp.NasBert
///
public class SentenceSimilarityTrainer : NasBertTrainer<float, float>
{
internal SentenceSimilarityTrainer(IHostEnvironment env, Options options) : base(env, options)

public class SentenceSimilarityOptions : NasBertOptions
{
public SentenceSimilarityOptions()
{
BatchSize = 32;
MaxEpoch = 10;
TaskType = BertTaskType.SentenceRegression;
LearningRate = new List<double>() { .0002 };
WeightDecay = .01;
}
}
internal SentenceSimilarityTrainer(IHostEnvironment env, SentenceSimilarityOptions options) : base(env, options)
{
}

Expand All @@ -71,7 +83,7 @@ internal SentenceSimilarityTrainer(IHostEnvironment env,
int maxEpochs = 10,
IDataView validationSet = null,
BertArchitecture architecture = BertArchitecture.Roberta) :
this(env, new NasBertOptions
this(env, new SentenceSimilarityOptions
{
ScoreColumnName = scoreColumnName,
Sentence1ColumnName = sentence1ColumnName,
Expand Down
14 changes: 12 additions & 2 deletions src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,17 @@ namespace Microsoft.ML.TorchSharp.NasBert
///
public class TextClassificationTrainer : NasBertTrainer<UInt32, long>
{
internal TextClassificationTrainer(IHostEnvironment env, NasBertOptions options) : base(env, options)
public class TextClassificationOptions : NasBertTrainer.NasBertOptions
{
public TextClassificationOptions()
{
TaskType = BertTaskType.TextClassification;
BatchSize = 32;
MaxEpoch = 10;
}
}

internal TextClassificationTrainer(IHostEnvironment env, TextClassificationOptions options) : base(env, options)
{
}

Expand All @@ -74,7 +84,7 @@ internal TextClassificationTrainer(IHostEnvironment env,
int maxEpochs = 10,
IDataView validationSet = null,
BertArchitecture architecture = BertArchitecture.Roberta) :
this(env, new NasBertOptions
this(env, new TextClassificationOptions
{
PredictionColumnName = predictionColumnName,
ScoreColumnName = scoreColumnName,
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public static TextClassificationTrainer TextClassification(
/// <returns></returns>
public static TextClassificationTrainer TextClassification(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
NasBertTrainer.NasBertOptions options)
TextClassificationTrainer.TextClassificationOptions options)
=> new TextClassificationTrainer(CatalogUtils.GetEnvironment(catalog), options);

/// <summary>
Expand Down Expand Up @@ -99,7 +99,7 @@ public static SentenceSimilarityTrainer SentenceSimilarity(
/// <returns></returns>
public static SentenceSimilarityTrainer SentenceSimilarity(
this RegressionCatalog.RegressionTrainers catalog,
NasBertTrainer.NasBertOptions options)
SentenceSimilarityTrainer.SentenceSimilarityOptions options)
=> new SentenceSimilarityTrainer(CatalogUtils.GetEnvironment(catalog), options);


Expand Down
4 changes: 1 addition & 3 deletions test/Microsoft.ML.Tests/TextClassificationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,12 @@ public void TestSentenceSimilarityLargeFileGpu()

var dataSplit = ML.Data.TrainTestSplit(dataView, testFraction: 0.2);

var options = new NasBertTrainer.NasBertOptions()
var options = new SentenceSimilarityTrainer.SentenceSimilarityOptions()
{
TaskType = BertTaskType.SentenceRegression,
Sentence1ColumnName = "search_term",
Sentence2ColumnName = "product_description",
LabelColumnName = "relevance",
LearningRate = new List<double>() { .0002 },
WeightDecay = .01
};

var estimator = ML.Regression.Trainers.SentenceSimilarity(options);
Expand Down

0 comments on commit 8952994

Please sign in to comment.