Skip to content

Commit

Permalink
Refactor model loading (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
kahrendt authored Mar 10, 2024
1 parent 9fe3709 commit 332e1b0
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 217 deletions.
2 changes: 1 addition & 1 deletion microwakeword/feature_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def create_fixed_size_clip(self, x, sr=16000):

assert max_samples_from_end > len(x)

samples_from_end = np.random.randint(len(x), max_samples_from_end + 1)
samples_from_end = np.random.randint(len(x), max_samples_from_end) + 1

dat[-samples_from_end : -samples_from_end + len(x)] = x

Expand Down
28 changes: 24 additions & 4 deletions microwakeword/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


from microwakeword.layers import delay
from microwakeword.layers import modes
from microwakeword.layers import stream
from microwakeword.layers import strided_drop
from microwakeword.layers import sub_spectral_normalization
Expand Down Expand Up @@ -210,7 +209,28 @@ def model_parameters(parser_nn):
)


def model(flags, config):
def spectrogram_slices_dropped(flags):
    """Count the spectrogram slices lost to valid padding in the conv stacks.

    Args:
        flags: data/model parameters; the fields ``cnn1_kernel_sizes``,
            ``cnn2_kernel_sizes`` and ``cnn2_dilation`` are read and each
            is passed through ``parse``.

    Returns:
        int: number of spectrogram slices dropped
    """
    dropped = 0

    # First stack: every kernel removes (kernel_size - 1) slices.
    for size in parse(flags.cnn1_kernel_sizes):
        dropped += size - 1

    # Second stack: each (kernel, dilation) pair removes twice the dilated
    # kernel extent.
    second_stack = zip(
        parse(flags.cnn2_kernel_sizes), parse(flags.cnn2_dilation)
    )
    for size, rate in second_stack:
        dropped += 2 * rate * (size - 1)

    return dropped


def model(flags, shape, batch_size):
"""Inception model.
It is based on paper:
Expand All @@ -224,8 +244,8 @@ def model(flags, config):
Keras model for training
"""
input_audio = tf.keras.layers.Input(
shape=modes.get_input_data_shape(config, modes.Modes.TRAINING),
batch_size=config["batch_size"],
shape=shape,
batch_size=batch_size,
)
net = input_audio

Expand Down
Loading

0 comments on commit 332e1b0

Please sign in to comment.