Skip to content

Commit

Permalink
Used direct Keras imports
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuguy96 committed Dec 3, 2023
1 parent 88d3fd5 commit 79b5088
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 26 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ soundfile~=0.12.1
resampy~=0.4.2
transformers~=4.35.2

-tensorflow==2.15.0
+tensorflow==2.15.0
+keras~=2.15.0
28 changes: 17 additions & 11 deletions stepcovnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import tensorflow as tf
import transformers
-from tensorflow.python.keras.initializers import he_uniform, Constant, glorot_uniform
-from tensorflow.python.keras.layers import (
+from keras import initializers
+from keras.layers import (
Bidirectional,
LSTM,
Conv2D,
Expand All @@ -24,7 +24,7 @@
Input,
Layer,
)
-from tensorflow.python.keras.models import load_model, Model
+from keras.models import load_model, Model
from transformers import GPT2Config, TFGPT2Model

from stepcovnet import config, constants
Expand Down Expand Up @@ -106,8 +106,8 @@ def __init__(
feature_concat = concatenate([arrow_model.output, audio_model.output])
model = Dense(
256,
-kernel_initializer=tf.keras.initializers.he_uniform(42),
-bias_initializer=tf.keras.initializers.Zeros(),
+kernel_initializer=initializers.HeUniform(42),
+bias_initializer=initializers.Zeros(),
)(feature_concat)
model = BatchNormalization()(model)
model = Activation("relu")(model)
Expand All @@ -116,8 +116,10 @@ def __init__(
model_output = Dense(
constants.NUM_ARROW_COMBS,
activation="softmax",
-bias_initializer=Constant(value=training_config.init_bias_correction),
-kernel_initializer=glorot_uniform(42),
+bias_initializer=initializers.Constant(
+value=training_config.init_bias_correction
+),
+kernel_initializer=initializers.GlorotUniform(42),
dtype=tf.float32,
name="onehot_encoded_arrows",
)(model)
Expand Down Expand Up @@ -411,7 +413,7 @@ def load(
class VggishAudioModel(AudioModel):
def _create_audio_model(
self, training_config: config.TrainingConfig, model_input: Input
-) -> tf.keras.layers.Layer:
+) -> Layer:
# Channel reduction
if training_config.dataset_config["NUM_CHANNELS"] > 1:
vggish_input = TimeDistributed(
Expand All @@ -421,8 +423,8 @@ def _create_audio_model(
strides=(1, 1),
activation="linear",
padding="same",
-kernel_initializer=he_uniform(42),
-bias_initializer=tf.keras.initializers.Zeros(),
+kernel_initializer=initializers.HeUniform(42),
+bias_initializer=initializers.Zeros(),
image_shape=model_input.shape[1:],
data_format="channels_last",
name="channel_reduction",
Expand All @@ -440,5 +442,9 @@ def _create_audio_model(
# VGGish model returns feature maps for avg/max pooling. Using LSTM for additional feature extraction.
# Might be able to replace this with another method in the future
return Bidirectional(
-LSTM(128, return_sequences=False, kernel_initializer=glorot_uniform(42))
+LSTM(
+128,
+return_sequences=False,
+kernel_initializer=initializers.GlorotUniform(42),
+)
)(model_output)
10 changes: 4 additions & 6 deletions stepcovnet/tf_config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
+from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
+from keras import mixed_precision

# tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)

-MIXED_PRECISION_POLICY = tf.keras.mixed_precision.Policy('mixed_float16')
-tf.keras.mixed_precision.set_global_policy(MIXED_PRECISION_POLICY)
+MIXED_PRECISION_POLICY = mixed_precision.Policy("mixed_float16")
+mixed_precision.set_global_policy(MIXED_PRECISION_POLICY)

tf.config.optimizer.set_jit(True)

Expand Down
16 changes: 8 additions & 8 deletions stepcovnet/training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict

import numpy as np
-import tensorflow as tf
+from keras import metrics, losses, optimizers

from stepcovnet import data, utils

Expand All @@ -11,14 +11,14 @@ class TrainingHyperparameters:

# TODO(https://github.com/cpuguy96/StepCOVNet/issues/2): Move all training hyperparameters into config file
DEFAULT_METRICS = [
-tf.keras.metrics.CategoricalAccuracy(name="acc"),
-tf.keras.metrics.Precision(name="pre"),
-tf.keras.metrics.Recall(name="rec"),
-tf.keras.metrics.AUC(curve="PR", name="pr_auc"),
-tf.keras.metrics.AUC(name="auc"),
+metrics.CategoricalAccuracy(name="acc"),
+metrics.Precision(name="pre"),
+metrics.Recall(name="rec"),
+metrics.AUC(curve="PR", name="pr_auc"),
+metrics.AUC(name="auc"),
]
-DEFAULT_LOSS = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.05)
-DEFAULT_OPTIMIZER = tf.keras.optimizers.Nadam(beta_1=0.99)
+DEFAULT_LOSS = losses.CategoricalCrossentropy(label_smoothing=0.05)
+DEFAULT_OPTIMIZER = optimizers.Nadam(beta_1=0.99)
DEFAULT_EPOCHS = 15
DEFAULT_PATIENCE = 3
DEFAULT_BATCH_SIZE = 32
Expand Down

0 comments on commit 79b5088

Please sign in to comment.