Skip to content

Commit

Permalink
moved test_visualize out of tests/data dir for the time being, create…
Browse files Browse the repository at this point in the history
…d tests for the visualize callback and the visualize method, removed the save model callback bc it wasn't being used.
  • Loading branch information
bgenchel committed Aug 13, 2024
1 parent 44d9e19 commit e590dfe
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 83 deletions.
25 changes: 2 additions & 23 deletions basic_pitch/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

from typing import Any, Dict
Expand All @@ -25,27 +24,6 @@
from basic_pitch import visualize


class SavedModelCallback(tf.keras.callbacks.Callback):
def __init__(self, output_path: str, monitor: str):
self.output_savemodel_path = output_path
self.monitor = monitor # 'val_loss' typically
self.best_loss_so_far = 1000000.0

def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None:
loss = logs.get(self.monitor)
if loss is None:
logging.warning("SaveModelCallback: monitored variable %s is not defined in logs, skipping" % self.monitor)
else:
if loss < self.best_loss_so_far:
output_path = os.path.join(
self.output_savemodel_path,
"%d/model" % epoch,
)
logging.info("SaveModelCallback: saving model at iteration %d in %s" % (epoch, output_path))
tf.saved_model.save(self.model, output_path)
self.best_loss_so_far = loss


class VisualizeCallback(tf.keras.callbacks.Callback):
# TODO RACHEL make this WAY faster
def __init__(
Expand All @@ -66,6 +44,7 @@ def __init__(

def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None:
# the first two outputs of generator needs to be the input and the targets
print(f"epoch: {epoch}, logs: {logs}")
train_inputs, train_targets = next(self.train_iter)[:2]
validation_inputs, validation_targets = next(self.validation_iter)[:2]
for stage, inputs, targets, loss in [
Expand Down
60 changes: 0 additions & 60 deletions tests/data/test_visualize.py

This file was deleted.

70 changes: 70 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import tensorflow as tf

from typing import Dict

from basic_pitch.callbacks import VisualizeCallback
from basic_pitch.constants import AUDIO_N_SAMPLES, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES


class MockModel(tf.keras.Model):
def __init__(self) -> None:
super(MockModel, self).__init__()

def call(self, inputs: tf.Tensor) -> Dict[str, tf.Tensor]:
return {
key: tf.random.normal((1, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES)) for key in ["onset", "contour", "note"]
}


def create_mock_dataset() -> tf.data.Dataset:
batch_size = 1
inputs = tf.random.normal((batch_size, AUDIO_N_SAMPLES, 1))
targets = {
key: tf.random.normal((batch_size, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES))
for key in ["onset", "contour", "note"]
}
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.batch(batch_size)
return dataset


def test_visualize_callback_on_epoch_end(tmpdir: str) -> None:
model = MockModel()
model.compile(optimizer="adam", loss="mse")

batch_size = 2 # needs to be at least 2 bc validation_split required

x_train = np.random.random((batch_size, AUDIO_N_SAMPLES, 1))
y_train = {
key: np.random.random((batch_size, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES))
for key in ["onset", "contour", "note"]
}

vc = VisualizeCallback(
train_ds=create_mock_dataset(),
validation_ds=create_mock_dataset(),
tensorboard_dir=str(tmpdir),
original_validation_ds=create_mock_dataset(),
contours=True,
)

history = model.fit(x_train, y_train, epochs=1, validation_split=0.5, callbacks=[vc])
assert history
44 changes: 44 additions & 0 deletions tests/test_visualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import tensorflow as tf

from basic_pitch.constants import AUDIO_N_SAMPLES, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES
from basic_pitch.visualize import visualize_transcription


def test_visualize_transcription(tmpdir: str) -> None:
inputs = tf.random.normal([1, AUDIO_N_SAMPLES, 1])
targets = {
key: tf.random.normal([1, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES]) for key in ["onset", "contour", "note"]
}
outputs = {
key: tf.random.normal([1, ANNOTATIONS_N_SEMITONES, ANNOT_N_FRAMES]) for key in ["onset", "contour", "note"]
}

visualize_transcription(
file_writer=tf.summary.create_file_writer(str(tmpdir)),
stage="train",
inputs=inputs,
targets=targets,
outputs=outputs,
loss=np.random.random(),
step=1,
sonify=True,
contours=True,
)

0 comments on commit e590dfe

Please sign in to comment.