Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert hack that leads to OOM during fine-tuning #3858

Merged
merged 2 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
Revert hack that leads to OOM during fine-tuning
  • Loading branch information
arnavgarg1 committed Jan 4, 2024
commit 4c61bafaf33e91b57d9e95f932a6652e689856a3
35 changes: 0 additions & 35 deletions ludwig/distributed/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ def prepare(
"""
pass

def eval(self, model: nn.Module):
"""Put ``model`` into evaluation mode by calling ``model.eval()``."""
model.eval()

def train(self, model: nn.Module, prev_model_training_mode: bool = None):
"""Set the training mode of ``model``.

If ``prev_model_training_mode`` is not None it is forwarded to
``model.train(...)``; otherwise ``model.train()`` is called with no argument.
"""
if prev_model_training_mode is not None:
model.train(prev_model_training_mode)
else:
model.train()

def prepare_for_inference(self, model: nn.Module) -> nn.Module:
"""Return ``model`` unchanged."""
# NOTE(review): presumably subclasses override this to wrap the model — confirm.
return model

Expand Down Expand Up @@ -207,10 +198,6 @@ def replace_model_from_serialization(cls, state: nn.Module | tuple[nn.Module, li


class LocalStrategy(DistributedStrategy):
def __init__(self):
"""Initialise the base strategy and an empty per-module training-mode record."""
super().__init__()
# Maps module name -> the ``training`` flag that module had when ``eval`` was last called.
self.module_name_to_prev_training_mode = {}

def prepare(
self,
model: nn.Module,
Expand All @@ -219,28 +206,6 @@ def prepare(
) -> tuple[nn.Module, Optimizer]:
return model, create_optimizer(model, trainer_config.optimizer, base_learning_rate)

def eval(self, model):
"""Record each named module's training flag, then switch that module to eval mode.

The flags are stored in ``self.module_name_to_prev_training_mode``, keyed by the
names yielded by ``model.named_modules()``; entries from earlier calls are overwritten.
"""
# HACK(geoffrey): use vanilla model.eval()
# when https://github.com/huggingface/transformers/issues/28023 is resolved.
for module_name, module in model.named_modules():
self.module_name_to_prev_training_mode[module_name] = module.training
module.eval()

def train(self, model, prev_model_training_mode=None):
"""Restore the per-module training modes recorded in ``module_name_to_prev_training_mode``.

Returns immediately when ``model.training`` is already true. Otherwise every named
module that has a recorded entry is set back to that recorded mode; modules without
an entry are left untouched. ``prev_model_training_mode`` is accepted but not read
by this method.
"""
# HACK(geoffrey): use vanilla model.train(prev_model_training_mode)
# when https://github.com/huggingface/transformers/issues/28023 is resolved.
# This hack ignores module.training updates if the model is already in training mode
# (to avoid touching LoRA configuration). Otherwise, the model was in eval mode, so we
# restore the previous training mode. We do not use prev_model_training_mode because we store the history
# as a dictionary mapping to training mode to each module.
if model.training:
return

for module_name, module in model.named_modules():
if module_name in self.module_name_to_prev_training_mode:
module.train(self.module_name_to_prev_training_mode[module_name])

def size(self) -> int:
"""Return 1."""
# NOTE(review): presumably the number of workers in the local (non-distributed) case — confirm.
return 1

Expand Down
34 changes: 16 additions & 18 deletions ludwig/models/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def __init__(

def batch_predict(self, dataset: Dataset, dataset_name: str = None, collect_logits: bool = False):
self.dist_model = self._distributed.to_device(self.dist_model)
prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)
prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode

with torch.no_grad():
with dataset.initialize_batcher(self._batch_size, should_shuffle=False) as batcher:
Expand All @@ -151,13 +151,13 @@ def batch_predict(self, dataset: Dataset, dataset_name: str = None, collect_logi
# consolidate predictions from each batch to a single tensor
self._concat_preds(predictions)

self._distributed.train(self.dist_model, prev_model_training_mode)
self.dist_model.train(prev_model_training_mode)

return from_numpy_dataset(predictions)

def predict_single(self, batch, collect_logits: bool = False):
prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)
prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode

with torch.no_grad():
predictions = defaultdict(list)
Expand All @@ -167,8 +167,8 @@ def predict_single(self, batch, collect_logits: bool = False):
)
self._concat_preds(predictions)

self._distributed.train(self.dist_model, prev_model_training_mode)

# reset model to its original training mode
self.dist_model.train(prev_model_training_mode)
return from_numpy_dataset(predictions)

def _predict(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
Expand Down Expand Up @@ -217,8 +217,8 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_logits=Fa
collect_predictions, collect_logits.
"""
self.dist_model = self._distributed.to_device(self.dist_model)
prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)
prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode

with torch.no_grad():
with dataset.initialize_batcher(
Expand Down Expand Up @@ -289,16 +289,16 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_logits=Fa
metrics = self.model.get_metrics()
self.model.reset_metrics()

self._distributed.train(self.dist_model, prev_model_training_mode)
self.dist_model.train(prev_model_training_mode) # Restores previous model training mode.

return metrics, from_numpy_dataset(predictions)

def batch_collect_activations(self, layer_names, dataset, bucketing_field=None):
if bucketing_field:
raise ValueError("BucketedBatcher is not supported yet")

prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)
prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode

with torch.no_grad():
with dataset.initialize_batcher(
Expand Down Expand Up @@ -328,7 +328,7 @@ def batch_collect_activations(self, layer_names, dataset, bucketing_field=None):

progress_bar.close()

self._distributed.train(self.dist_model, prev_model_training_mode)
self.dist_model.train(prev_model_training_mode) # Restores previous model training mode.

return collected_tensors

Expand Down Expand Up @@ -361,9 +361,8 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_logits=Fa
dictionary are "inputs", "targets", and "outputs". The values of each of these keys are dictionaries of
feature names to lists of tensors. The tensors are the inputs, targets, and outputs for each batch.
"""
prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)

prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode
example_inputs = defaultdict(list)
example_targets = defaultdict(list)
example_outputs = defaultdict(list)
Expand Down Expand Up @@ -455,8 +454,7 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_logits=Fa
"outputs": example_outputs,
}

self._distributed.train(self.dist_model, prev_model_training_mode)

self.dist_model.train(prev_model_training_mode) # Restores previous model training mode.
return metrics, from_numpy_dataset(predictions), input_target_output_dict


Expand Down
3 changes: 2 additions & 1 deletion ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,8 @@ def train(
# epoch init
start_time = time.time()

self.distributed.train(self.dist_model)
# Epoch start: switch the model to training mode, then reset the metrics
self.dist_model.train() # Sets model to training mode.
self.model.reset_metrics()

self.callback(lambda c: c.on_epoch_start(self, progress_tracker, save_path))
Expand Down