Skip to content

Commit

Permalink
added additional docstrings for tf_example_deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
bgenchel committed Aug 15, 2024
1 parent 435b5fa commit 1a2dfa6
Showing 1 changed file with 66 additions and 7 deletions.
73 changes: 66 additions & 7 deletions basic_pitch/data/tf_example_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,16 @@ def prepare_datasets(
Return a training and a testing dataset.
Args:
training_shuffle_buffer_size : size of shuffle buffer (only for training set)
batch_size : ..
datasets_base_path: path to tfrecords for input data
training_shuffle_buffer_size: size of shuffle buffer (only for training set)
batch_size: batch size for training and validation
validation_steps: number of batches to use for validation
datasets_to_use: the underlying datasets to use for creating training and validation sets e.g. guitarset
dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
are sampled from during training / validation dataset creation.
Returns:
training and validation datasets derived from the underlying tfrecord data
"""
assert batch_size > 0
assert validation_steps is not None and validation_steps > 0
Expand Down Expand Up @@ -108,11 +116,18 @@ def prepare_visualization_datasets(
dataset_sampling_frequency: np.ndarray,
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
"""
Return a training and a testing dataset.
Return a training and a testing dataset for visualization
Args:
training_shuffle_buffer_size : size of shuffle buffer (only for training set)
batch_size : ..
datasets_base_path: path to tfrecord datasets for input data
batch_size: batch size for training and validation
validation_steps: number of batches to use for validation
datasets_to_use: the underlying datasets to use for creating training and validation sets e.g. guitarset
dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
are sampled from during training / validation dataset creation.
Returns:
training and validation datasets derived from the underlying tfrecord data
"""

assert batch_size > 0
Expand Down Expand Up @@ -162,6 +177,21 @@ def sample_datasets(
n_samples_per_track: int = N_SAMPLES_PER_TRACK,
pairs: bool = False,
) -> tf.data.Dataset:
"""samples tfrecord data to create a dataset
Args:
split: whether to use training or validation data
dataset_base_path: directory storing source datasets as tfrecord files
datasets: names of datasets to sample from e.g. guitarset
dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
are sampled from during training / validation dataset creation.
n_shuffle: size of shuffle buffer (only used for training ds)
n_samples_per_track: the number of samples to take from a track
pairs: generate pairs of samples from the dataset rather than individual samples
Returns:
dataset of samples
"""
if split == Split.validation:
n_shuffle = 0
pairs = False
Expand Down Expand Up @@ -214,8 +244,14 @@ def transcription_file_generator(
datasets_base_path: str,
sample_weights: np.ndarray,
) -> Tuple[Callable[[], Iterator[tf.Tensor]], bool]:
"""
dataset_names: list of dataset dataset_names
"""Reads underlying files and returns file generator
Args:
split: data split to build generator from
dataset_names: list of dataset_names to use
datasets_base_path: directory storing source datasets as tfrecord files
sample_weights: distribution weighting vector corresponding to datasets determining how they
are sampled from during training / validation dataset creation.
"""
file_dict = {
dataset_name: tf.data.Dataset.list_files(
Expand All @@ -230,6 +266,7 @@ def transcription_file_generator(


def _train_file_generator(x: Dict[str, tf.data.Dataset], weights: np.ndarray) -> Iterator[tf.Tensor]:
"""file generator for training sets"""
x = {k: list(v) for (k, v) in x.items()}
keys = list(x.keys())
# shuffle each list
Expand All @@ -243,6 +280,7 @@ def _train_file_generator(x: Dict[str, tf.data.Dataset], weights: np.ndarray) ->


def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[tf.Tensor]:
"""file generator for validation sets"""
x = {k: list(v) for (k, v) in x.items()}
# loop until there are no more test files
while any(x.values()):
Expand All @@ -258,6 +296,13 @@ def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[tf.Ten
def combine_transcription_examples(
a: tf.Tensor, target: Dict[str, tf.Tensor], w: Dict[str, tf.Tensor]
) -> Tuple[tf.Tensor, Dict[str, tf.Tensor], Dict[str, tf.Tensor]]:
"""mix pairs together for paired dataset
Args:
a: audio data
target: target data (onset, notes, contours)
w: weights
"""
return (
# mix the audio snippets
tf.math.reduce_mean(a, axis=0),
Expand Down Expand Up @@ -396,6 +441,7 @@ def is_not_bad_shape(
notes_onsets_shape: tf.Tensor,
_contours_shape: tf.Tensor,
) -> tf.Tensor:
"""checks for improper data shape for note values and onsets"""
bad_shape = tf.logical_and(
tf.shape(notes_values)[0] == 0,
tf.shape(notes_onsets_shape)[0] == 2,
Expand All @@ -404,6 +450,7 @@ def is_not_bad_shape(


def sparse2dense(values: tf.Tensor, indices: tf.Tensor, dense_shape: tf.Tensor) -> tf.Tensor:
"""converts sparse tensor representation to dense vector"""
if tf.rank(indices) != 2 and tf.size(indices) == 0:
indices = tf.zeros([0, 1], dtype=indices.dtype)
tf.assert_rank(indices, 2)
Expand Down Expand Up @@ -563,6 +610,17 @@ def trim_time(data: np.ndarray, start: int, duration: int, sr: int) -> tf.Tensor
def extract_window(
audio: tf.Tensor, onsets: np.ndarray, contour: np.ndarray, notes: np.ndarray, t_start: int
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
"""extracts a window of data from the given audio and its associated metadata
Args:
audio: audio signal
onsets: note onsets of audio signal
contour: pitch on off of audio signal
notes: note on off of audio signal
t_start: start of the window within the audio signal
Returns:
tuple of windows of each of the inputs
"""
# needs a hop size extra of samples for good mel spectrogram alignment
audio_trim = trim_time(
audio,
Expand Down Expand Up @@ -665,6 +723,7 @@ def is_not_all_silent_annotations(
contour_weight: int,
note_weight: int,
) -> tf.Tensor:
"""returns a boolean tensor that is True when the pitch contour or the notes have a nonzero mean, i.e. the annotations are not all silent."""
contours_nonsilent = tf.math.reduce_mean(contour) != 0
notes_nonsilent = tf.math.reduce_mean(notes) != 0
return tf.math.logical_or(contours_nonsilent, notes_nonsilent)
Expand Down

0 comments on commit 1a2dfa6

Please sign in to comment.