Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed mac build and minor torch sharp changes #6776

Merged
merged 1 commit into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build/vsts-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ stages:
pool:
vmImage: macOS-12
steps:
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && brew update && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 && rm '/usr/local/bin/2to3-3.11' && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
displayName: Install build dependencies
# Only build native assets to avoid conflicts.
- script: ./build.sh -projects $(Build.SourcesDirectory)/src/Native/Native.proj -configuration $(BuildConfig) /p:TargetArchitecture=x64 /p:CopyPackageAssets=true
Expand Down
16 changes: 14 additions & 2 deletions src/Microsoft.ML.TorchSharp/NasBert/SentenceSimilarityTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,19 @@ namespace Microsoft.ML.TorchSharp.NasBert
///
public class SentenceSimilarityTrainer : NasBertTrainer<float, float>
{
internal SentenceSimilarityTrainer(IHostEnvironment env, Options options) : base(env, options)

public class SentenceSimilarityOptions : NasBertOptions
{
public SentenceSimilarityOptions()
{
BatchSize = 32;
MaxEpoch = 10;
TaskType = BertTaskType.SentenceRegression;
LearningRate = new List<double>() { .0002 };
WeightDecay = .01;
}
}
internal SentenceSimilarityTrainer(IHostEnvironment env, SentenceSimilarityOptions options) : base(env, options)
{
}

Expand All @@ -71,7 +83,7 @@ internal SentenceSimilarityTrainer(IHostEnvironment env,
int maxEpochs = 10,
IDataView validationSet = null,
BertArchitecture architecture = BertArchitecture.Roberta) :
this(env, new NasBertOptions
this(env, new SentenceSimilarityOptions
{
ScoreColumnName = scoreColumnName,
Sentence1ColumnName = sentence1ColumnName,
Expand Down
14 changes: 12 additions & 2 deletions src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,17 @@ namespace Microsoft.ML.TorchSharp.NasBert
///
public class TextClassificationTrainer : NasBertTrainer<UInt32, long>
{
internal TextClassificationTrainer(IHostEnvironment env, NasBertOptions options) : base(env, options)
public class TextClassificationOptions : NasBertTrainer.NasBertOptions
{
public TextClassificationOptions()
{
TaskType = BertTaskType.TextClassification;
BatchSize = 32;
MaxEpoch = 10;
}
}

internal TextClassificationTrainer(IHostEnvironment env, TextClassificationOptions options) : base(env, options)
{
}

Expand All @@ -74,7 +84,7 @@ internal TextClassificationTrainer(IHostEnvironment env,
int maxEpochs = 10,
IDataView validationSet = null,
BertArchitecture architecture = BertArchitecture.Roberta) :
this(env, new NasBertOptions
this(env, new TextClassificationOptions
{
PredictionColumnName = predictionColumnName,
ScoreColumnName = scoreColumnName,
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public static TextClassificationTrainer TextClassification(
/// <returns></returns>
public static TextClassificationTrainer TextClassification(
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
NasBertTrainer.NasBertOptions options)
TextClassificationTrainer.TextClassificationOptions options)
=> new TextClassificationTrainer(CatalogUtils.GetEnvironment(catalog), options);

/// <summary>
Expand Down Expand Up @@ -99,7 +99,7 @@ public static SentenceSimilarityTrainer SentenceSimilarity(
/// <returns></returns>
public static SentenceSimilarityTrainer SentenceSimilarity(
this RegressionCatalog.RegressionTrainers catalog,
NasBertTrainer.NasBertOptions options)
SentenceSimilarityTrainer.SentenceSimilarityOptions options)
=> new SentenceSimilarityTrainer(CatalogUtils.GetEnvironment(catalog), options);


Expand Down
4 changes: 1 addition & 3 deletions test/Microsoft.ML.Tests/TextClassificationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,12 @@ public void TestSentenceSimilarityLargeFileGpu()

var dataSplit = ML.Data.TrainTestSplit(dataView, testFraction: 0.2);

var options = new NasBertTrainer.NasBertOptions()
var options = new SentenceSimilarityTrainer.SentenceSimilarityOptions()
{
TaskType = BertTaskType.SentenceRegression,
Sentence1ColumnName = "search_term",
Sentence2ColumnName = "product_description",
LabelColumnName = "relevance",
LearningRate = new List<double>() { .0002 },
WeightDecay = .01
};

var estimator = ML.Regression.Trainers.SentenceSimilarity(options);
Expand Down
Loading