Skip to content

Commit

Permalink
removed test_visualize bc its tested by the callback, added an as yet…
Browse files Browse the repository at this point in the history
… unused enum for splits, added a test for the visualize callback that doesn't include model fit.
  • Loading branch information
bgenchel committed Aug 14, 2024
1 parent e590dfe commit f00dde0
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 48 deletions.
3 changes: 1 addition & 2 deletions basic_pitch/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
original_validation_ds: tf.data.Dataset,
contours: bool,
):
super(VisualizeCallback, self).__init__()
super().__init__()
self.train_iter = iter(train_ds)
self.validation_iter = iter(validation_ds)
self.validation_ds = original_validation_ds
Expand All @@ -44,7 +44,6 @@ def __init__(

def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None:
# the first two outputs of generator needs to be the input and the targets
print(f"epoch: {epoch}, logs: {logs}")
train_inputs, train_targets = next(self.train_iter)[:2]
validation_inputs, validation_targets = next(self.validation_iter)[:2]
for stage, inputs, targets, loss in [
Expand Down
6 changes: 5 additions & 1 deletion basic_pitch/constants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@

import numpy as np

from enum import Enum

FFT_HOP = 256
N_FFT = 8 * FFT_HOP

Expand Down Expand Up @@ -59,3 +61,5 @@ def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int)

FREQ_BINS_NOTES = _freq_bins(NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)
FREQ_BINS_CONTOURS = _freq_bins(CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)

Splits = Enum("Splits", ["train", "validation", "test"])
20 changes: 19 additions & 1 deletion tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
# limitations under the License.

import numpy as np
import os
import tensorflow as tf

from typing import Dict

from basic_pitch.callbacks import VisualizeCallback
from basic_pitch.constants import AUDIO_N_SAMPLES, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


class MockModel(tf.keras.Model):
def __init__(self) -> None:
Expand All @@ -47,6 +50,21 @@ def create_mock_dataset() -> tf.data.Dataset:


def test_visualize_callback_on_epoch_end(tmpdir: str) -> None:

vc = VisualizeCallback(
train_ds=create_mock_dataset(),
validation_ds=create_mock_dataset(),
tensorboard_dir=str(tmpdir),
original_validation_ds=create_mock_dataset(),
contours=True,
)

vc.model = MockModel()

vc.on_epoch_end(1, {"loss": np.random.random(), "val_loss": np.random.random()})


def test_visualize_callback_on_epoch_end_with_model(tmpdir: str) -> None:
model = MockModel()
model.compile(optimizer="adam", loss="mse")

Expand All @@ -66,5 +84,5 @@ def test_visualize_callback_on_epoch_end(tmpdir: str) -> None:
contours=True,
)

history = model.fit(x_train, y_train, epochs=1, validation_split=0.5, callbacks=[vc])
history = model.fit(x_train, y_train, epochs=1, validation_split=0.5, callbacks=[vc], verbose=0)
assert history
44 changes: 0 additions & 44 deletions tests/test_visualize.py

This file was deleted.

0 comments on commit f00dde0

Please sign in to comment.