Skip to content

Commit

Permalink
updated train method to use generators as params
Browse files Browse the repository at this point in the history
  • Loading branch information
lalouikarim authored May 23, 2024
1 parent 615bd7c commit 194e932
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions mrcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2272,7 +2272,7 @@ def set_log_dir(self, model_path=None):
self.checkpoint_path = self.checkpoint_path.replace(
"*epoch*", "{epoch:04d}")

def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
def train(self, train_generator, val_generator, learning_rate, epochs, layers,
augmentation=None, custom_callbacks=None, no_augmentation_sources=None):
"""Train the model.
train_dataset, val_dataset: Training and validation Dataset objects.
Expand Down Expand Up @@ -2322,14 +2322,6 @@ def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
if layers in layer_regex.keys():
layers = layer_regex[layers]

# Data generators
train_generator = data_generator(train_dataset, self.config, shuffle=True,
augmentation=augmentation,
batch_size=self.config.BATCH_SIZE,
no_augmentation_sources=no_augmentation_sources)
val_generator = data_generator(val_dataset, self.config, shuffle=True,
batch_size=self.config.BATCH_SIZE)

# Create log_dir if it does not exist
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
Expand Down

0 comments on commit 194e932

Please sign in to comment.