-
Notifications
You must be signed in to change notification settings - Fork 70
/
Copy pathwandb_atkt_train.py
29 lines (24 loc) · 1.24 KB
/
wandb_atkt_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import argparse
from wandb_train import main
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="assist2015")
parser.add_argument("--model_name", type=str, default="atkt")
parser.add_argument("--emb_type", type=str, default="qid")
parser.add_argument("--save_dir", type=str, default="saved_model")
# parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--fold", type=int, default=0)
parser.add_argument("--dropout", type=float, default=0.2)
parser.add_argument("--skill_dim", type=int, default=256)
parser.add_argument("--answer_dim", type=int, default=96)
parser.add_argument("--hidden_dim", type=int, default=80)
parser.add_argument("--attention_dim", type=int, default=80)
parser.add_argument("--epsilon", type=int, default=10)
parser.add_argument("--beta", type=float, default=0.2)
parser.add_argument("--learning_rate", type=float, default=1e-3)
parser.add_argument("--use_wandb", type=int, default=1)
parser.add_argument("--add_uuid", type=int, default=1)
args = parser.parse_args()
params = vars(args)
main(params)