diff --git a/basic_pitch/data/tf_example_deserialization.py b/basic_pitch/data/tf_example_deserialization.py
index b59667d..ae8b70a 100644
--- a/basic_pitch/data/tf_example_deserialization.py
+++ b/basic_pitch/data/tf_example_deserialization.py
@@ -51,8 +51,16 @@ def prepare_datasets(
     Return a training and a testing dataset.
 
     Args:
-    training_shuffle_buffer_size : size of shuffle buffer (only for training set)
-    batch_size : ..
+        datasets_base_path: path to tfrecords for input data
+        training_shuffle_buffer_size: size of shuffle buffer (only for training set)
+        batch_size: batch size for training and validation
+        validation_steps: number of batches to use for validation
+        datasets_to_use: the underlying datasets to use for creating training and validation sets, e.g. guitarset
+        dataset_sampling_frequency: distribution weighting vector, corresponding to datasets_to_use, that determines
+            how often each dataset is sampled from during training / validation dataset creation.
+
+    Returns:
+        training and validation datasets derived from the underlying tfrecord data
     """
     assert batch_size > 0
     assert validation_steps is not None and validation_steps > 0
@@ -108,11 +116,18 @@ def prepare_visualization_datasets(
     dataset_sampling_frequency: np.ndarray,
 ) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
     """
-    Return a training and a testing dataset.
+    Return a training and a testing dataset for visualization.
 
     Args:
-    training_shuffle_buffer_size : size of shuffle buffer (only for training set)
-    batch_size : ..
+        datasets_base_path: path to tfrecord datasets for input data
+        batch_size: batch size for training and validation
+        validation_steps: number of batches to use for validation
+        datasets_to_use: the underlying datasets to use for creating training and validation sets, e.g. guitarset
+        dataset_sampling_frequency: distribution weighting vector, corresponding to datasets_to_use, that determines
+            how often each dataset is sampled from during training / validation dataset creation.
+
+    Returns:
+        training and validation datasets derived from the underlying tfrecord data
     """
 
     assert batch_size > 0
@@ -162,6 +177,21 @@ def sample_datasets(
     n_samples_per_track: int = N_SAMPLES_PER_TRACK,
     pairs: bool = False,
 ) -> tf.data.Dataset:
+    """samples tfrecord data to create a dataset
+
+    Args:
+        split: whether to use training or validation data
+        dataset_base_path: directory storing source datasets as tfrecord files
+        datasets: names of datasets to sample from, e.g. guitarset
+        dataset_sampling_frequency: distribution weighting vector, corresponding to datasets, that determines
+            how often each dataset is sampled from during training / validation dataset creation.
+        n_shuffle: size of shuffle buffer (only used for the training dataset)
+        n_samples_per_track: the number of samples to take from a track
+        pairs: generate pairs of samples from the dataset rather than individual samples
+
+    Returns:
+        dataset of samples
+    """
     if split == Split.validation:
         n_shuffle = 0
         pairs = False
@@ -214,8 +244,14 @@ def transcription_file_generator(
     datasets_base_path: str,
     sample_weights: np.ndarray,
 ) -> Tuple[Callable[[], Iterator[tf.Tensor]], bool]:
-    """
-    dataset_names: list of dataset dataset_names
+    """Reads underlying files and returns a file generator
+
+    Args:
+        split: data split to build the generator from
+        dataset_names: list of dataset names to use
+        datasets_base_path: directory storing source datasets as tfrecord files
+        sample_weights: distribution weighting vector, corresponding to dataset_names, that determines
+            how often each dataset is sampled from during training / validation dataset creation.
     """
     file_dict = {
         dataset_name: tf.data.Dataset.list_files(
@@ -230,6 +266,7 @@ def transcription_file_generator(
 
 
 def _train_file_generator(x: Dict[str, tf.data.Dataset], weights: np.ndarray) -> Iterator[tf.Tensor]:
+    """file generator for training sets"""
     x = {k: list(v) for (k, v) in x.items()}
     keys = list(x.keys())
     # shuffle each list
@@ -243,6 +280,7 @@ def _train_file_generator(x: Dict[str, tf.data.Dataset], weights: np.ndarray) ->
 
 
 def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[tf.Tensor]:
+    """file generator for validation sets"""
    x = {k: list(v) for (k, v) in x.items()}
     # loop until there are no more test files
     while any(x.values()):
@@ -258,6 +296,13 @@ def _validation_file_generator(x: Dict[str, tf.data.Dataset]) -> Iterator[tf.Ten
 def combine_transcription_examples(
     a: tf.Tensor, target: Dict[str, tf.Tensor], w: Dict[str, tf.Tensor]
 ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor], Dict[str, tf.Tensor]]:
+    """mixes pairs of examples together for a paired dataset
+
+    Args:
+        a: audio data
+        target: target data (onsets, notes, contours)
+        w: weights
+    """
     return (
         # mix the audio snippets
         tf.math.reduce_mean(a, axis=0),
@@ -396,6 +441,7 @@ def is_not_bad_shape(
     notes_onsets_shape: tf.Tensor,
     _contours_shape: tf.Tensor,
 ) -> tf.Tensor:
+    """checks that the note values and onsets do not have an improper data shape"""
     bad_shape = tf.logical_and(
         tf.shape(notes_values)[0] == 0,
         tf.shape(notes_onsets_shape)[0] == 2,
@@ -404,6 +450,7 @@
 
 
 def sparse2dense(values: tf.Tensor, indices: tf.Tensor, dense_shape: tf.Tensor) -> tf.Tensor:
+    """converts a sparse tensor representation to a dense tensor"""
     if tf.rank(indices) != 2 and tf.size(indices) == 0:
         indices = tf.zeros([0, 1], dtype=indices.dtype)
     tf.assert_rank(indices, 2)
@@ -563,6 +610,18 @@ def trim_time(data: np.ndarray, start: int, duration: int, sr: int) -> tf.Tensor
 def extract_window(
     audio: tf.Tensor, onsets: np.ndarray, contour: np.ndarray, notes: np.ndarray, t_start: int
 ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
+    """extracts a window of data from the given audio and its associated metadata
+
+    Args:
+        audio: audio signal
+        onsets: note onsets of the audio signal
+        contour: pitch contour activations of the audio signal
+        notes: note on / off activations of the audio signal
+        t_start: start time of the window to extract
+
+    Returns:
+        tuple of windows of each of the inputs
+    """
     # needs a hop size extra of samples for good mel spectrogram alignment
     audio_trim = trim_time(
         audio,
@@ -665,6 +723,7 @@ def is_not_all_silent_annotations(
     contour_weight: int,
     note_weight: int,
 ) -> tf.Tensor:
+    """returns True if the notes or the pitch contour are not all zero, i.e. not entirely silent"""
     contours_nonsilent = tf.math.reduce_mean(contour) != 0
     notes_nonsilent = tf.math.reduce_mean(notes) != 0
     return tf.math.logical_or(contours_nonsilent, notes_nonsilent)