Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated typing to Python 3.10 #26

Merged
merged 1 commit into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 49 additions & 30 deletions stepcovnet/config.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from abc import ABC
from typing import Type, Union
from typing import Type

import numpy as np
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

from stepcovnet import dataset, training, constants, utils


class AbstractConfig(ABC, object):
def __init__(self, dataset_config, lookback, difficulty, *args, **kwargs):
class AbstractConfig(ABC):
def __init__(
self, dataset_config: dict, lookback: int, difficulty: str, *args, **kwargs
):
self.dataset_config = dataset_config
self.lookback = lookback
self.difficulty = difficulty

@property
def arrow_input_shape(self):
def arrow_input_shape(self) -> tuple[None,]:
return (None,)

@property
def arrow_mask_shape(self):
def arrow_mask_shape(self) -> tuple[None,]:
return (None,)

@property
def audio_input_shape(self):
def audio_input_shape(self) -> tuple[int, int, int, int]:
return (
self.lookback,
self.dataset_config["NUM_TIME_BANDS"],
Expand All @@ -31,13 +34,19 @@ def audio_input_shape(self):
)

@property
def label_shape(self):
def label_shape(self) -> tuple[int,]:
return (constants.NUM_ARROW_COMBS,)


class InferenceConfig(AbstractConfig):
def __init__(
self, audio_path, file_name, dataset_config, lookback, difficulty, scalers=None
self,
audio_path: str,
file_name: str,
dataset_config: dict,
lookback: int,
difficulty: str,
scalers: list[preprocessing.StandardScaler] | None = None,
):
super(InferenceConfig, self).__init__(
dataset_config=dataset_config, lookback=lookback, difficulty=difficulty
Expand All @@ -52,9 +61,9 @@ def __init__(
self,
dataset_path: str,
dataset_type: Type[dataset.ModelDataset],
dataset_config,
dataset_config: dict,
hyperparameters: training.TrainingHyperparameters,
all_scalers=None,
all_scalers: list[preprocessing.StandardScaler] | None = None,
limit: int = -1,
lookback: int = 1,
difficulty: str = "challenge",
Expand Down Expand Up @@ -85,13 +94,15 @@ def __init__(
self.init_bias_correction = self.get_init_bias_correction()
self.train_scalers = self.get_train_scalers()

def get_train_val_split(self) -> Union[np.array, np.array, np.array]:
def get_train_val_split(
self,
) -> tuple[np.ndarray[int], np.ndarray[int], np.ndarray[int]]:
all_indexes = []
with self.enter_dataset as dataset:
with self.enter_dataset as model_dataset:
total_samples = 0
index = 0
for song_start_index, song_end_index in dataset.song_index_ranges:
if not any(dataset.labels[song_start_index:song_end_index] < 0):
for song_start_index, song_end_index in model_dataset.song_index_ranges:
if not any(model_dataset.labels[song_start_index:song_end_index] < 0):
all_indexes.append(index)
total_samples += song_end_index - song_start_index
if 0 < self.limit < total_samples:
Expand All @@ -103,12 +114,14 @@ def get_train_val_split(self) -> Union[np.array, np.array, np.array]:
)
return all_indexes, train_indexes, val_indexes

def get_class_weights(self, indexes) -> dict:
def get_class_weights(self, indexes: np.ndarray[int]) -> dict:
labels = None
with self.enter_dataset as dataset:
with self.enter_dataset as model_dataset:
for index in indexes:
song_start_index, song_end_index = dataset.song_index_ranges[index]
encoded_arrows = dataset.onehot_encoded_arrows[
song_start_index, song_end_index = model_dataset.song_index_ranges[
index
]
encoded_arrows = model_dataset.onehot_encoded_arrows[
song_start_index:song_end_index
]
if labels is None:
Expand All @@ -134,40 +147,46 @@ def get_class_weights(self, indexes) -> dict:

return dict(enumerate(class_weights))

def get_init_bias_correction(self) -> np.ndarray:
def get_init_bias_correction(self) -> float:
# Best practices mentioned in
# https://www.tensorflow.org/tutorials/structured_data/imbalanced_data#optional_set_the_correct_initial_bias
# Not completely correct but works for now
num_all = self.num_train_samples
num_pos = 0
with self.enter_dataset as dataset:
with self.enter_dataset as model_dataset:
for index in self.train_indexes:
song_start_index, song_end_index = dataset.song_index_ranges[index]
num_pos += dataset.labels[song_start_index:song_end_index].sum()
song_start_index, song_end_index = model_dataset.song_index_ranges[
index
]
num_pos += model_dataset.labels[song_start_index:song_end_index].sum()
num_neg = num_all - num_pos
return np.log(num_pos / num_neg)

def get_train_scalers(self):
def get_train_scalers(self) -> list | None:
training_scalers = None
with self.enter_dataset as dataset:
with self.enter_dataset as model_dataset:
for index in self.train_indexes:
song_start_index, song_end_index = dataset.song_index_ranges[index]
features = dataset.features[song_start_index:song_end_index]
song_start_index, song_end_index = model_dataset.song_index_ranges[
index
]
features = model_dataset.features[song_start_index:song_end_index]
training_scalers = utils.get_channel_scalers(
features, existing_scalers=training_scalers
)
return training_scalers

def get_num_samples(self, indexes) -> int:
def get_num_samples(self, indexes: np.ndarray[int]) -> int:
num_all = 0
with self.enter_dataset as dataset:
with self.enter_dataset as model_dataset:
for index in indexes:
song_start_index, song_end_index = dataset.song_index_ranges[index]
song_start_index, song_end_index = model_dataset.song_index_ranges[
index
]
num_all += song_end_index - song_start_index
return num_all

@property
def enter_dataset(self):
def enter_dataset(self) -> dataset.ModelDataset:
return self.dataset_type(
self.dataset_path, difficulty=self.difficulty
).__enter__()
10 changes: 7 additions & 3 deletions stepcovnet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

NUM_TIME_BANDS = 15

THRESHOLDS = {'expert': 0.5}
THRESHOLDS = {"expert": 0.5}

NUM_ARROW_TYPES = 4 # TODO: Move this to dataset config

Expand All @@ -23,15 +23,19 @@
NUM_ARROWS = len(ARROW_NAMES)


def get_all_note_combs(num_note_types: int) -> np.ndarray:
def get_all_note_combs(num_note_types: int) -> list[str]:
all_note_combs = []

for first_digit in range(0, num_note_types):
for second_digit in range(0, num_note_types):
for third_digit in range(0, num_note_types):
for fourth_digit in range(0, num_note_types):
all_note_combs.append(
str(first_digit) + str(second_digit) + str(third_digit) + str(fourth_digit))
str(first_digit)
+ str(second_digit)
+ str(third_digit)
+ str(fourth_digit)
)
# Adding '0000' to possible note combinations.
# This will allow the arrow prediction model to predict an empty note.
return all_note_combs
Expand Down
Loading