Commit: Split enum

bgenchel committed Aug 15, 2024
1 parent 06c4181 commit 435b5fa
Showing 3 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion basic_pitch/constants.py
@@ -62,4 +62,4 @@ def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int)
 FREQ_BINS_NOTES = _freq_bins(NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)
 FREQ_BINS_CONTOURS = _freq_bins(CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES)
 
-Splits = Enum("Splits", ["train", "validation", "test"])
+Split = Enum("Split", ["train", "validation", "test"])
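
For context, a minimal sketch of how the renamed enum behaves. This is illustrative only and uses nothing beyond Python's standard functional Enum API, which is exactly how constants.py declares it:

from enum import Enum

# Same functional-API declaration as in basic_pitch/constants.py:
Split = Enum("Split", ["train", "validation", "test"])

assert Split.train.name == "train"              # .name recovers the on-disk folder name
assert Split.train != "train"                   # members no longer compare equal to raw strings
assert Split["validation"] is Split.validation  # lookup by name still works when needed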
20 changes: 10 additions & 10 deletions basic_pitch/data/tf_example_deserialization.py
@@ -33,6 +33,7 @@
     AUDIO_WINDOW_LENGTH,
     N_FREQ_BINS_NOTES,
     N_FREQ_BINS_CONTOURS,
+    Split,
 )
 
 N_SAMPLES_PER_TRACK = 20
@@ -59,13 +60,13 @@ def prepare_datasets(
 
     # init both
     ds_train = sample_datasets(
-        "train",
+        Split.train,
         datasets_base_path,
         datasets=datasets_to_use,
         dataset_sampling_frequency=dataset_sampling_frequency,
     )
     ds_validation = sample_datasets(
-        "validation",
+        Split.validation,
         datasets_base_path,
         datasets=datasets_to_use,
         dataset_sampling_frequency=dataset_sampling_frequency,
@@ -118,14 +119,14 @@ def prepare_visualization_datasets(
     assert validation_steps is not None and validation_steps > 0
 
     ds_train = sample_datasets(
-        "train",
+        Split.train,
         datasets_base_path,
         datasets=datasets_to_use,
        dataset_sampling_frequency=dataset_sampling_frequency,
         n_samples_per_track=1,
     )
     ds_validation = sample_datasets(
-        "validation",
+        Split.validation,
         datasets_base_path,
         datasets=datasets_to_use,
         dataset_sampling_frequency=dataset_sampling_frequency,
@@ -153,16 +154,15 @@ def prepare_visualization_datasets(
 
 
 def sample_datasets(
-    split: str,
+    split: Split,
     datasets_base_path: str,
     datasets: List[str],
     dataset_sampling_frequency: np.ndarray,
     n_shuffle: int = 1000,
     n_samples_per_track: int = N_SAMPLES_PER_TRACK,
     pairs: bool = False,
 ) -> tf.data.Dataset:
-    assert split in ["train", "validation"]
-    if split == "validation":
+    if split == Split.validation:
         n_shuffle = 0
         pairs = False
         if n_samples_per_track != 1:
@@ -209,7 +209,7 @@ def sample_datasets(
 
 
 def transcription_file_generator(
-    split: str,
+    split: Split,
     dataset_names: List[str],
     datasets_base_path: str,
     sample_weights: np.ndarray,
@@ -219,12 +219,12 @@
     """
     file_dict = {
         dataset_name: tf.data.Dataset.list_files(
-            os.path.join(datasets_base_path, dataset_name, "splits", split, "*tfrecord")
+            os.path.join(datasets_base_path, dataset_name, "splits", split.name, "*tfrecord")
         )
         for dataset_name in dataset_names
     }
 
-    if split == "train":
+    if split == Split.train:
         return lambda: _train_file_generator(file_dict, sample_weights), False
     return lambda: _validation_file_generator(file_dict), True
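
A minimal sketch of the pattern this file now follows: the enum member's .name supplies the on-disk folder, so a typo in a split fails loudly at attribute lookup instead of silently producing an empty file glob. The helper name tfrecord_glob is hypothetical; only os.path.join, the "splits" directory layout, and the Split enum come from the diff above:

import os
from enum import Enum

Split = Enum("Split", ["train", "validation", "test"])

def tfrecord_glob(datasets_base_path: str, dataset_name: str, split: Split) -> str:
    # Mirrors transcription_file_generator: the member name is the directory name.
    return os.path.join(datasets_base_path, dataset_name, "splits", split.name, "*tfrecord")

print(tfrecord_glob("/data", "guitarset", Split.train))
# -> /data/guitarset/splits/train/*tfrecord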
11 changes: 6 additions & 5 deletions tests/data/test_tf_example_deserialization.py
@@ -24,6 +24,7 @@
 from apache_beam.testing.test_pipeline import TestPipeline
 from typing import List
 
+from basic_pitch.constants import Split
 from basic_pitch.data.datasets.guitarset import GuitarSetToTfExample
 from basic_pitch.data.pipeline import WriteBatchToTfRecord
 from basic_pitch.data.tf_example_deserialization import (
@@ -135,7 +136,7 @@ def test_sample_datasets(tmp_path: pathlib.Path) -> None:
     datasets_home = setup_test_resources(tmp_path)
 
     ds = sample_datasets(
-        split="train",
+        split=Split.train,
         datasets_base_path=str(datasets_home),
         datasets=["guitarset"],
         dataset_sampling_frequency=np.array([1]),
@@ -148,12 +149,12 @@ def test_sample_datasets(tmp_path: pathlib.Path) -> None:
 
 
 def test_transcription_file_generator_train(tmp_path: pathlib.Path) -> None:
-    dataset_path = tmp_path / "test_ds" / "splits" / "train"
+    dataset_path = tmp_path / "test_ds" / "splits" / Split.train.name
     dataset_path.mkdir(parents=True)
     create_empty_tfrecord(dataset_path / "test.tfrecord")
 
     file_gen, random_seed = transcription_file_generator(
-        "train", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1])
+        Split.train, ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1])
     )
 
     assert random_seed is False
@@ -167,12 +168,12 @@ def test_transcription_file_generator_train(tmp_path: pathlib.Path) -> None:
 
 
 def test_transcription_file_generator_valid(tmp_path: pathlib.Path) -> None:
-    dataset_path = tmp_path / "test_ds" / "splits" / "valid"
+    dataset_path = tmp_path / "test_ds" / "splits" / Split.validation.name
     dataset_path.mkdir(parents=True)
     create_empty_tfrecord(dataset_path / "test.tfrecord")
 
     file_gen, random_seed = transcription_file_generator(
-        "valid", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1])
+        Split.validation, ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1])
     )
 
     assert random_seed is True
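
One behavioral detail in the last hunk: the validation test previously created a "valid" directory, while the enum member's name is "validation", so the path the test builds actually changes, not just its spelling in code. A standalone check, assuming only the enum definition from constants.py:

from enum import Enum

Split = Enum("Split", ["train", "validation", "test"])

# The directory the test now creates comes from the member name:
assert Split.validation.name == "validation"  # not "valid", as hard-coded before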
