Skip to content

Commit

Permalink
Added typing
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuguy96 committed Dec 2, 2023
1 parent bed231b commit e2786f6
Showing 1 changed file with 102 additions and 52 deletions.
154 changes: 102 additions & 52 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,92 @@
from stepcovnet.training.TrainingHyperparameters import TrainingHyperparameters


def load_training_data(input_path: str) -> tuple:
    """Load the dataset location, type, scalers and config for a training run.

    Reads ``metadata.json`` from ``input_path`` and uses its ``dataset_name``
    entry to derive the ``<dataset_name>_dataset`` path and to load the
    ``<dataset_name>_scaler.pkl`` file from the same directory.

    Args:
        input_path: Directory containing ``metadata.json`` and the scaler file.

    Returns:
        A ``(dataset_path, dataset_type, scalers, dataset_config)`` tuple.

    Raises:
        FileNotFoundError: If ``metadata.json`` or the scaler file is missing.
        KeyError: If the metadata lacks ``dataset_name``, ``dataset_type`` or
            ``config`` (and, presumably, if ``dataset_type`` is not a name
            known to ``ModelDatasetTypes`` -- confirm against that type).
    """
    # Context managers guarantee both handles are closed even when parsing
    # fails; the previous bare open() calls left them to the garbage collector.
    with open(os.path.join(input_path, "metadata.json"), "r") as metadata_file:
        metadata = json.load(metadata_file)

    dataset_name = metadata["dataset_name"]
    dataset_type = ModelDatasetTypes[metadata["dataset_type"]].value
    dataset_path = os.path.join(input_path, dataset_name + "_dataset")

    # NOTE: joblib.load unpickles arbitrary objects, so input_path must only
    # ever point at trusted training data.
    scaler_path = os.path.join(input_path, dataset_name + "_scaler.pkl")
    with open(scaler_path, "rb") as scaler_file:
        scalers = joblib.load(scaler_file)

    dataset_config = metadata["config"]
    return dataset_path, dataset_type, scalers, dataset_config


def run_training(
    input_path: str,
    output_path: str,
    model_name: str,
    limit: int,
    lookback: int,
    difficulty: str,
    log_path: str,
):
    """Assemble the training configuration and models, then run training.

    Args:
        input_path: Directory handed to ``load_training_data``.
        output_path: Root path given to ``StepCOVNetModel`` (converted with
            ``str``).
        model_name: Name given to ``StepCOVNetModel``.
        limit: Frame limit forwarded to ``TrainingConfig``.
        lookback: Lookback frame count forwarded to ``TrainingConfig``.
        difficulty: Difficulty name forwarded to ``TrainingConfig``.
        log_path: Log path forwarded to ``TrainingHyperparameters``.
    """
    loaded = load_training_data(input_path)
    dataset_path, dataset_type, scalers, dataset_config = loaded

    # Hyperparameters are created first because the config embeds them.
    hparams = TrainingHyperparameters(log_path=log_path)
    config_kwargs = dict(
        dataset_path=dataset_path,
        dataset_type=dataset_type,
        dataset_config=dataset_config,
        hyperparameters=hparams,
        all_scalers=scalers,
        limit=limit,
        lookback=lookback,
        difficulty=difficulty,
        tokenizer_name=Tokenizers.GPT2.name,
    )
    training_input = TrainingInput(TrainingConfig(**config_kwargs))

    # Every sub-model is built from the config held by the training input.
    shared_config = training_input.config
    arrow_model = GPT2ArrowModel(shared_config)
    audio_model = VggishAudioModel(shared_config)
    classifier = ClassifierModel(shared_config, arrow_model, audio_model)

    wrapped_model = StepCOVNetModel(
        model_root_path=str(output_path),
        model_name=model_name,
        model=classifier.model,
    )

    executor = TrainingExecutor(stepcovnet_model=wrapped_model)
    executor.execute(input_data=training_input)


def train(
input_path: str,
output_path: str,
difficulty_int: int,
lookback: int,
limit: int,
name: str,
log_path: str,
):
if not os.path.isdir(input_path):
raise NotADirectoryError('Input path %s not found' % os.path.abspath(input_path))
raise NotADirectoryError(
"Input path %s not found" % os.path.abspath(input_path)
)

if not os.path.isdir(output_path):
print("Model output path not found. Creating directory...")
os.makedirs(output_path, exist_ok=True)

if lookback <= 1:
raise ValueError('Lookback needs to be > 1')
raise ValueError("Lookback needs to be > 1")

if limit == 0:
raise ValueError('Limit cannot be = 0')
raise ValueError("Limit cannot be = 0")

if name is not None and not name:
raise ValueError('Model name cannot be empty')
raise ValueError("Model name cannot be empty")

if log_path is not None and not os.path.isdir(log_path):
print("Log output path not found. Creating directory...")
os.makedirs(log_path, exist_ok=True)

if log_path is not None:
log_path = os.path.join(log_path, "tensorboard", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
log_path = os.path.join(
log_path, "tensorboard", datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
)

difficulty = ["challenge", "hard", "medium", "easy", "beginner"][difficulty_int]

Expand All @@ -81,44 +116,59 @@ def train(input_path, output_path, difficulty_int, lookback, limit, name, log_pa
output_path = os.path.join(output_path, model_name)
os.makedirs(output_path, exist_ok=True)

run_training(input_path=input_path, output_path=output_path, model_name=model_name, limit=limit, lookback=lookback,
difficulty=difficulty, log_path=log_path)
run_training(
input_path=input_path,
output_path=output_path,
model_name=model_name,
limit=limit,
lookback=lookback,
difficulty=difficulty,
log_path=log_path,
)


def build_arg_parser():
    """Create the command-line parser for the training script.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--input``, ``--output``,
        ``--difficulty``, ``--lookback``, ``--limit``, ``--name`` and
        ``--log``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Train a note timings model")

    # Both paths are mandatory: without them train() would receive None and
    # fail inside os.path.isdir instead of printing a usage message.
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Input training data path"
    )
    parser.add_argument(
        "-o", "--output", type=str, required=True, help="Output stored model path"
    )
    parser.add_argument(
        "-d",
        "--difficulty",
        type=int,
        default=0,
        choices=[0, 1, 2, 3, 4],
        help="Game difficulty to use when training: 0 - challenge, 1 - hard, "
        "2 - medium, 3 - easy, 4 - beginner",
    )
    # NOTE(review): train() rejects lookback <= 1, so the "1 - non timeseries"
    # mode advertised in this help text is currently unreachable.
    parser.add_argument(
        "--lookback",
        type=int,
        default=2,
        help="Number of frames to lookback when training: 1 - non timeseries, > 1 timeseries",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=-1,
        help="Maximum number of frames to use when training: -1 unlimited, > 0 frame limit",
    )
    parser.add_argument(
        "--name", type=str, default=None, help="Name to give finished model"
    )
    parser.add_argument(
        "--log", type=str, default=None, help="Output log data path for tensorboard"
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    train(
        input_path=args.input,
        output_path=args.output,
        difficulty_int=args.difficulty,
        lookback=args.lookback,
        limit=args.limit,
        name=args.name,
        log_path=args.log,
    )

0 comments on commit e2786f6

Please sign in to comment.