Skip to content

Commit

Permalink
reviewing more TODO comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sammlapp committed Oct 7, 2024
1 parent 5f858b0 commit 2d50b2d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 0 additions & 2 deletions opensoundscape/ml/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def training_step(self, samples, batch_idx):
batch_size = len(batch_tensors)

# automatic mixed precision
# TODO: add tests with self.use_amp=True
# can get rid of if/else blocks and use enabled=True
# once mps is supported https://github.com/pytorch/pytorch/pull/99272

Expand Down Expand Up @@ -290,7 +289,6 @@ def training_step(self, samples, batch_idx):
batch_metrics, on_epoch=True, on_step=False, batch_size=batch_size
)
# when on_epoch=True, compute() is called to reset the metric at epoch end
# TODO: log this somehow when not in lightning_mode?

return loss

Expand Down
9 changes: 6 additions & 3 deletions opensoundscape/ml/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def fit_with_trainer(
accelerator, precision, logger, accumulate_grad_batches, etc.
Note: the `max_epochs` kwarg is overridden by the `epochs` argument
Returns:
a trained pytorch_lightning.Trainer object
Effects:
If wandb_session is provided, logs progress and samples to Weights
and Biases. A random set of training and validation samples
Expand Down Expand Up @@ -434,7 +437,7 @@ def predict_with_trainer(
num_workers=num_workers,
raise_errors=raise_errors,
**dataloader_kwargs,
) # TODO: add test for kwargs
)

# check for matching class list
if len(dataloader.dataset.dataset.classes) > 0 and list(self.classes) != list(
Expand All @@ -444,6 +447,7 @@ def predict_with_trainer(
"The columns of input samples df differ from `model.classes`."
)

# Could re-add logging samples to wandb table:
# if wandb_session is not None:
# # update the run config with information about the model
# wandb_session.config.update(self._generate_wandb_config())
Expand All @@ -469,11 +473,10 @@ def predict_with_trainer(

### Prediction/Inference ###
# iterate dataloader and run inference (forward pass) to generate scores
# TODO: add test for kwargs
trainer = L.Trainer(**lightning_trainer_kwargs)
pred_scores = torch.vstack(trainer.predict(self, dataloader))

### Apply activation layer ### #TODO: test speed vs. doing it in __call__ on batches
### Apply activation layer ###
pred_scores = apply_activation_layer(pred_scores, activation_layer)

# return DataFrame with same index/columns as prediction_dataset's df
Expand Down

0 comments on commit 2d50b2d

Please sign in to comment.