diff --git a/basic_pitch/callbacks.py b/basic_pitch/callbacks.py index ca575ec..c23d883 100644 --- a/basic_pitch/callbacks.py +++ b/basic_pitch/callbacks.py @@ -34,7 +34,7 @@ def __init__( original_validation_ds: tf.data.Dataset, contours: bool, ): - super(VisualizeCallback, self).__init__() + super().__init__() self.train_iter = iter(train_ds) self.validation_iter = iter(validation_ds) self.validation_ds = original_validation_ds @@ -44,7 +44,6 @@ def __init__( def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None: # the first two outputs of generator needs to be the input and the targets - print(f"epoch: {epoch}, logs: {logs}") train_inputs, train_targets = next(self.train_iter)[:2] validation_inputs, validation_targets = next(self.validation_iter)[:2] for stage, inputs, targets, loss in [ diff --git a/basic_pitch/constants.py b/basic_pitch/constants.py index a78a487..0123dec 100644 --- a/basic_pitch/constants.py +++ b/basic_pitch/constants.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # encoding: utf-8 # -# Copyright 2022 Spotify AB +# Copyright 2024 Spotify AB # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ import numpy as np +from enum import Enum + FFT_HOP = 256 N_FFT = 8 * FFT_HOP @@ -59,3 +61,5 @@ def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int) FREQ_BINS_NOTES = _freq_bins(NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES) FREQ_BINS_CONTOURS = _freq_bins(CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES) + +Splits = Enum("Splits", ["train", "validation", "test"]) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 9bd49d6..bf3929e 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -16,6 +16,7 @@ # limitations under the License. import numpy as np +import os import tensorflow as tf from typing import Dict @@ -23,6 +24,8 @@ from basic_pitch.callbacks import VisualizeCallback from basic_pitch.constants import AUDIO_N_SAMPLES, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + class MockModel(tf.keras.Model): def __init__(self) -> None: @@ -47,6 +50,21 @@ def create_mock_dataset() -> tf.data.Dataset: def test_visualize_callback_on_epoch_end(tmpdir: str) -> None: + + vc = VisualizeCallback( + train_ds=create_mock_dataset(), + validation_ds=create_mock_dataset(), + tensorboard_dir=str(tmpdir), + original_validation_ds=create_mock_dataset(), + contours=True, + ) + + vc.model = MockModel() + + vc.on_epoch_end(1, {"loss": np.random.random(), "val_loss": np.random.random()}) + + +def test_visualize_callback_on_epoch_end_with_model(tmpdir: str) -> None: model = MockModel() model.compile(optimizer="adam", loss="mse") @@ -66,5 +84,5 @@ def test_visualize_callback_on_epoch_end(tmpdir: str) -> None: contours=True, ) - history = model.fit(x_train, y_train, epochs=1, validation_split=0.5, callbacks=[vc]) + history = model.fit(x_train, y_train, epochs=1, validation_split=0.5, callbacks=[vc], verbose=0) assert history diff --git a/tests/test_visualize.py b/tests/test_visualize.py deleted file mode 100644 index 0917bd4..0000000 --- a/tests/test_visualize.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env python -# encoding: utf-8 -# -# Copyright 2024 Spotify AB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import tensorflow as tf - -from basic_pitch.constants import AUDIO_N_SAMPLES, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES -from basic_pitch.visualize import visualize_transcription - - -def test_visualize_transcription(tmpdir: str) -> None: - inputs = tf.random.normal([1, AUDIO_N_SAMPLES, 1]) - targets = { - key: tf.random.normal([1, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES]) for key in ["onset", "contour", "note"] - } - outputs = { - key: tf.random.normal([1, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES]) for key in ["onset", "contour", "note"] - } - - visualize_transcription( - file_writer=tf.summary.create_file_writer(str(tmpdir)), - stage="train", - inputs=inputs, - targets=targets, - outputs=outputs, - loss=np.random.random(), - step=1, - sonify=True, - contours=True, - )