Skip to content

Commit

Permalink
解决keras模式下,使用GPU训练时会爆显存的bug。
Browse files Browse the repository at this point in the history
观察到的现象是,一些模型增大batchsize后,会在首个epoch的中途爆显存不足,只要过了一个epoch后,就能完整训练。同样的batchsize在python下能设置大得多的值。
最后使用最小训练代码分析出,是每个step之后,图片加载到显存里的数据没有释放导致的。
在寻找释放显存接口没有结果的时候,直接使用了GC.Collect();可以让显存主动回收。
因此当前的修复方案是在每个step里,都执行一次 GC.Collect(); 用来释放显存资源。
  • Loading branch information
dogvane committed Oct 8, 2023
1 parent 5e4f530 commit baf620a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 7 deletions.
23 changes: 23 additions & 0 deletions src/TensorFlowNET.Core/Keras/Engine/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ICallback fit(NDArray x, NDArray y,
List<ICallback> callbacks = null,
float validation_split = 0f,
ValidationDataPack validation_data = null,
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
Expand All @@ -47,6 +48,20 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
int workers = 1,
bool use_multiprocessing = false);

/// <summary>
/// Trains the model on batches yielded by a <see cref="IDatasetV2"/> pipeline,
/// optionally running validation against a second dataset.
/// </summary>
/// <param name="dataset">Dataset yielding (input, target) batches for training.</param>
/// <param name="batch_size">Batch size; -1 leaves batching to the dataset pipeline.</param>
/// <param name="epochs">Number of passes over the dataset.</param>
/// <param name="verbose">Verbosity level for progress output.</param>
/// <param name="callbacks">Callbacks invoked during training.</param>
/// <param name="validation_data">Optional dataset used to evaluate validation metrics.</param>
/// <param name="validation_step">Run validation once every this many steps.</param>
/// <param name="shuffle">Whether to shuffle the data each epoch.</param>
/// <param name="class_weight">Optional per-class loss weighting.</param>
/// <param name="initial_epoch">Epoch at which to start (useful when resuming).</param>
/// <param name="max_queue_size">Maximum size of the data prefetch queue.</param>
/// <param name="workers">Number of data-loading workers.</param>
/// <param name="use_multiprocessing">Whether data loading may use multiprocessing.</param>
/// <returns>An <see cref="ICallback"/> (history) recording the training run.</returns>
public ICallback fit(IDatasetV2 dataset,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
IDatasetV2 validation_data = null,
int validation_step = 10, // run validation once every this many steps
bool shuffle = true,
Dictionary<int, float> class_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

void save(string filepath,
bool overwrite = true,
bool include_optimizer = true,
Expand Down Expand Up @@ -85,6 +100,14 @@ Tensors predict(Tensors x,
int workers = 1,
bool use_multiprocessing = false);

/// <summary>
/// Generates output predictions for the batches yielded by a <see cref="IDatasetV2"/> pipeline.
/// </summary>
/// <param name="dataset">Dataset yielding input batches to predict on.</param>
/// <param name="batch_size">Batch size; -1 leaves batching to the dataset pipeline.</param>
/// <param name="verbose">Verbosity level for progress output.</param>
/// <param name="steps">Total number of prediction steps; -1 means run until the dataset is exhausted.</param>
/// <param name="max_queue_size">Maximum size of the data prefetch queue.</param>
/// <param name="workers">Number of data-loading workers.</param>
/// <param name="use_multiprocessing">Whether data loading may use multiprocessing.</param>
/// <returns>The concatenated model outputs as <see cref="Tensors"/>.</returns>
public Tensors predict(IDatasetV2 dataset,
int batch_size = -1,
int verbose = 0,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

void summary(int line_length = -1, float[] positions = null);

IKerasConfig get_config();
Expand Down
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callba
var end_step = step + data_handler.StepIncrement;
if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
GC.Collect();
}
}
callbacks.on_test_end(logs);
Expand Down Expand Up @@ -167,7 +168,9 @@ Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handl
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x,y) = data_handler.DataAdapter.Expand1d(x, y);

var y_pred = Apply(x, training: false);

var loss = compiled_loss.Call(y, y_pred);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
Expand Down
12 changes: 6 additions & 6 deletions src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public ICallback fit(NDArray x, NDArray y,
List<ICallback> callbacks = null,
float validation_split = 0f,
ValidationDataPack validation_data = null,
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
Expand Down Expand Up @@ -147,7 +148,7 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
}
}

public History fit(IDatasetV2 dataset,
public ICallback fit(IDatasetV2 dataset,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
Expand All @@ -156,7 +157,6 @@ public History fit(IDatasetV2 dataset,
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
Expand All @@ -170,7 +170,7 @@ public History fit(IDatasetV2 dataset,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
SampleWeight = sample_weight,
ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Expand Down Expand Up @@ -218,6 +218,7 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
GC.Collect();
}

if (validation_data != null)
Expand All @@ -233,11 +234,10 @@ History FitInternal(DataHandler data_handler, int epochs, int validation_step, i
callbacks.on_train_batch_end(End_step, logs);
}

GC.Collect();

callbacks.on_epoch_end(epoch, logs);

GC.Collect();
GC.WaitForPendingFinalizers();
if (stop_training)
{
break;
Expand Down Expand Up @@ -282,6 +282,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
GC.Collect();
}

if (validation_data != null)
Expand All @@ -301,7 +302,6 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
callbacks.on_epoch_end(epoch, logs);

GC.Collect();
GC.WaitForPendingFinalizers();
if (stop_training)
{
break;
Expand Down
2 changes: 1 addition & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Predict.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ Tensors PredictInternal(DataHandler data_handler, int verbose)
for (int i = 0; i < batch_outputs.Length; i++)
batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0);
}

var end_step = step + data_handler.StepIncrement;
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
GC.Collect();
}
}

Expand Down

0 comments on commit baf620a

Please sign in to comment.