Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mcleavey committed Sep 12, 2018
1 parent 6420f3b commit ac19d99
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions make_critic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
TRAIN=Path("./critic_data/train")

def make_critic_data(num_to_generate, replace, prefix, model_to_load, training, gen_size, use_test_prompt, generator_bs, tt_split):
'''
Generates samples of real and fake data. Files are written to:
critic_data/test/fake, critic_data/test/real, critic_data/train/real, critic_data/train/fake
Inputs:
num_to_generate - the number of fake pieces to create. These are then split into files bptt long (according to the backprop size of the model)
replace - if true, clear the existing files in the critic_data folders and replace with these new files
prefix - add prefix to generated files (needed if not replacing the old data)
model_to_load - model for generating fake pieces
training - whether to use light, med, full, or extra training on those models
gen_size - fake piece generation size (pieces are often better at the beginning and sound worse the longer they run)
use_test_prompt - generate fake pieces from test prompts (pieces the model hasn't seen before)
generator_bs - number of pieces to generate in parallel
tt_split - test/train split (fraction of pieces to put into the test folder)
'''

PATHS=create_paths()

# Load pretrained model and training/test text
Expand Down Expand Up @@ -43,7 +59,7 @@ def make_critic_data(num_to_generate, replace, prefix, model_to_load, training,
trunc_size=random.randint(1,10), prompts=prompts,
params=params, TEXT=TEXT)

# Write to train/real and train/fake, or test/real and test/fake
# Write to train/fake and test/fake
# Choose randomly whether train or test, according to the test_train_split (tt_split) frequency

num_samples=0
Expand All @@ -58,6 +74,8 @@ def make_critic_data(num_to_generate, replace, prefix, model_to_load, training,
f.close()
num_samples+=1

# Pull human-composed samples randomly from the prompts. Add to train/real and test/real
# Choose randomly whether train or test, according to the test_train_split (tt_split) frequency
musical_prompts=generate_musical_prompts(prompts, bptt, num_samples)
for i in range(num_samples):
dest=TEST if random.random()<tt_split else TRAIN
Expand All @@ -69,7 +87,7 @@ def make_critic_data(num_to_generate, replace, prefix, model_to_load, training,
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-model", help="Trained model in ./data/models", required=True)
parser.add_argument("-num", help="Number of files to generate (default 1000)", type=int)
parser.add_argument("--num", help="Number of files to generate (default 1000)", type=int)
parser.set_defaults(num=1000)
parser.add_argument("--replace", dest="replace", action="store_true", help="Overwrite existing test/train critic data")
parser.set_defaults(replace=False)
Expand Down

0 comments on commit ac19d99

Please sign in to comment.