Skip to content

Commit

Permalink
Up-directory sample_collection_helper.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuguy96 committed Dec 2, 2023
1 parent 562e94b commit 96a3a81
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 10 deletions.
Empty file.
5 changes: 2 additions & 3 deletions stepcovnet/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
import numpy as np
import tensorflow as tf

from stepcovnet import config, training
from stepcovnet import config, training, sample_collection_helper
from stepcovnet.common.utils import get_samples_ngram_with_mask
from stepcovnet.data_collection.sample_collection_helper import get_audio_features


class AbstractInput(ABC, object):
Expand All @@ -16,7 +15,7 @@ def __init__(self, input_config, *args, **kwargs):
class InferenceInput(AbstractInput):
def __init__(self, inference_config: config.InferenceConfig):
super(InferenceInput, self).__init__(input_config=inference_config)
self.audio_features = get_audio_features(
self.audio_features = sample_collection_helper.get_audio_features(
wav_path=self.config.audio_path,
file_name=self.config.file_name,
config=self.config.dataset_config,
Expand Down
File renamed without changes.
12 changes: 5 additions & 7 deletions training_data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,13 @@
import joblib
import psutil

from stepcovnet import data
from stepcovnet import data, sample_collection_helper
from stepcovnet.common.parameters import CONFIG, VGGISH_CONFIG
from stepcovnet.common.utils import (
get_channel_scalers,
get_filename,
get_filenames_from_folder,
)
from stepcovnet.data_collection.sample_collection_helper import (
feature_onset_phrase_label_sample_weights,
get_features_and_labels,
)


def build_all_metadata(**kwargs):
Expand Down Expand Up @@ -50,7 +46,9 @@ def collect_features(wav_path, timing_path, config, cores, file_name):
binary_encoded_arrows,
string_arrows,
onehot_encoded_arrows,
) = get_features_and_labels(wav_path, timing_path, file_name, config)
) = sample_collection_helper.get_features_and_labels(
wav_path, timing_path, file_name, config
)
(
feature,
label_dict,
Expand All @@ -60,7 +58,7 @@ def collect_features(wav_path, timing_path, config, cores, file_name):
binary_encoded_arrows_dict,
string_arrows_dict,
onehot_encoded_arrows_dict,
) = feature_onset_phrase_label_sample_weights(
) = sample_collection_helper.feature_onset_phrase_label_sample_weights(
onsets,
log_mel,
arrows,
Expand Down

0 comments on commit 96a3a81

Please sign in to comment.