Skip to content

Commit

Permalink
add checkAlive in NasBertTrainer (#6546)
Browse files (browse the repository at this point in the history)
  • Loading branch information
LittleLittleCloud committed Jan 23, 2023
1 parent eeba2ee commit a06dadc
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs
Original file line number | Diff line number | Diff line change
Expand Up @@ -249,7 +249,8 @@ public override NasBertTransformer Fit(IDataView input)
for (int i = 0; i < Option.MaxEpoch; i++)
{
ch.Trace($"Starting epoch {i}");
- trainer.Train(input);
+ Host.CheckAlive();
+ trainer.Train(Host, input);
ch.Trace($"Finished epoch {i}");
if (Option.ValidationSet != null)
trainer.Validate(pch, ch, i);
Expand Down Expand Up @@ -423,7 +424,7 @@ private bool ValidateStep(DataViewRowCursor cursor,
return cursorValid;
}

- public void Train(IDataView input)
+ public void Train(IHost host, IDataView input)
{
// Get the cursor and the correct columns based on the inputs
DataViewRowCursor cursor = default;
Expand All @@ -443,14 +444,15 @@ public void Train(IDataView input)
var cursorValid = true;
while (cursorValid)
{
- cursorValid = TrainStep(cursor, sentence1Getter, sentence2Getter, labelGetter, ref inputTensors, ref targets);
+ cursorValid = TrainStep(host, cursor, sentence1Getter, sentence2Getter, labelGetter, ref inputTensors, ref targets);
}
}

- private bool TrainStep(DataViewRowCursor cursor,
- ValueGetter<ReadOnlyMemory<char>> sentence1Getter,
- ValueGetter<ReadOnlyMemory<char>> sentence2Getter,
- ValueGetter<TLabelCol> labelGetter,
+ private bool TrainStep(IHost host,
+ DataViewRowCursor cursor,
+ ValueGetter<ReadOnlyMemory<char>> sentence1Getter,
+ ValueGetter<ReadOnlyMemory<char>> sentence2Getter,
+ ValueGetter<TLabelCol> labelGetter,
ref List<Tensor> inputTensors,
ref List<TTargetsCol> targets)
{
Expand All @@ -461,6 +463,7 @@ private bool TrainStep(DataViewRowCursor cursor,
var cursorValid = true;
for (int i = 0; i < Parent.Option.BatchSize && cursorValid; i++)
{
host.CheckAlive();
cursorValid = cursor.MoveNext();
if (cursorValid)
{
Expand All @@ -479,7 +482,7 @@ private bool TrainStep(DataViewRowCursor cursor,
}

Updates++;

host.CheckAlive();
torch.random.manual_seed(1 + Updates);
torch.cuda.manual_seed(1 + Updates);
Model.train();
Expand All @@ -497,8 +500,10 @@ private bool TrainStep(DataViewRowCursor cursor,
loss = torch.nn.MSELoss(reduction: Parent.Option.Reduction).forward(logits, targetsTensor);
logits = logits.squeeze();
}

host.CheckAlive();
loss.backward();

host.CheckAlive();
OptimizeStep();

return cursorValid;
Expand Down

0 comments on commit a06dadc

Please sign in to comment.