-
Notifications
You must be signed in to change notification settings - Fork 310
/
sweep.py
39 lines (32 loc) · 1.26 KB
/
sweep.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import subprocess
import sys
from typing import List, Union
import fire
def main(model_sizes: Union[List[str], str], **kwargs):
    """Run ``train_simple.py`` for every ground-truth and weak-to-strong pairing.

    First trains one ground-truth model per entry of *model_sizes*. Then it
    trains one transfer model for every pair ``(weak, strong)`` where *strong*
    does not precede *weak* in *model_sizes* (this includes ``weak == strong``).
    Each run is a separate subprocess: ``train_simple.py`` located next to this
    file, executed with the current Python interpreter.

    Args:
        model_sizes: Model names to sweep, either as a sequence or as a single
            comma-separated string. In the string form, whitespace around
            names is stripped and empty entries are ignored.
        **kwargs: Extra options forwarded to every ``train_simple.py``
            invocation as ``--<key> <str(value)>``.

    Raises:
        ValueError: If ``weak_model_size``, ``model_size`` or
            ``weak_labels_path`` is supplied in *kwargs*. The sweep sets the
            size arguments itself for each run.
        subprocess.CalledProcessError: If any training run exits with a
            non-zero status. The sweep stops at the first failure.
    """
    import itertools

    if isinstance(model_sizes, str):
        # Tolerate "a, b" and trailing/duplicate commas instead of launching
        # runs with a blank or space-prefixed model name.
        model_sizes = [part.strip() for part in model_sizes.split(",") if part.strip()]
    # subprocess argv entries must be str; normalize in case a caller passes
    # a sequence containing non-string values.
    model_sizes = [str(size) for size in model_sizes]

    # Explicit raise rather than `assert` so the check survives `python -O`.
    reserved = ("weak_model_size", "model_size", "weak_labels_path")
    if any(key in kwargs for key in reserved):
        raise ValueError("Need to use model_sizes when using sweep.py")

    script = os.path.join(os.path.dirname(__file__), "train_simple.py")
    basic_args = [sys.executable, script]
    for key, value in kwargs.items():
        basic_args.extend([f"--{key}", str(value)])

    # flush=True keeps these progress lines correctly ordered relative to the
    # child processes' output when stdout is redirected to a file or pipe.
    print("Running ground truth models", flush=True)
    for model_size in model_sizes:
        subprocess.run(basic_args + ["--model_size", model_size], check=True)

    print("Running transfer models", flush=True)
    # Yields (sizes[i], sizes[j]) for every i <= j, in the same order as
    # nested index loops.
    pairs = itertools.combinations_with_replacement(model_sizes, 2)
    for weak_model_size, strong_model_size in pairs:
        print(f"Running weak {weak_model_size} to strong {strong_model_size}", flush=True)
        subprocess.run(
            basic_args
            + ["--weak_model_size", weak_model_size, "--model_size", strong_model_size],
            check=True,
        )
if __name__ == "__main__":
    # Script entry point: expose main() as a command-line interface via fire.
    fire.Fire(component=main)