-
Notifications
You must be signed in to change notification settings - Fork 2
/
average_ckpt.sh
39 lines (38 loc) · 1.05 KB
/
average_ckpt.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Average the epoch checkpoints of the "diff_only_relative" ablation run and
# evaluate the averaged weights with `train --mode test`.
#
# Environment overrides (optional):
#   LOCAL_NUM - passed to train as --local_num_layers and also selects the
#               run directory log/ablation/diff_only_relative${LOCAL_NUM}/
#               (default: 2).
#
# Outputs: log/ablation/diff_only_relative${LOCAL_NUM}/average/average_ckpt.pt

# Abort on the first failing step so the evaluation never runs against a
# missing or stale averaged checkpoint; -u also catches misspelled variables.
set -euo pipefail

LOCAL_NUM=${LOCAL_NUM:-2}
ckpt_pth="log/ablation/diff_only_relative${LOCAL_NUM}/"
avg_path="${ckpt_pth}average"
avg_ckpt="${avg_path}/average_ckpt.pt"

# -p: succeed if the directory already exists (e.g. when re-running).
mkdir -p -- "$avg_path"

# NOTE(review): presumably averages the last 5 epoch checkpoints found in
# $ckpt_pth, considering epochs up to 42 (fairseq-style options) — confirm
# against the average_ckpt module.
python -m average_ckpt \
  --inputs "$ckpt_pth" \
  --output "$avg_ckpt" \
  --num-epoch-checkpoints 5 \
  --checkpoint-upper-bound 42

# Evaluate the averaged checkpoint. Model hyper-parameters must match the
# ones used for training, otherwise the checkpoint will not load as intended.
python -m train \
  --mode "test" \
  --log_dir "$ckpt_pth" \
  --data_path "data_bin/PHOENIX2014T" \
  --embedding_dim 512 \
  --hidden_size 512 \
  --num_heads 8 \
  --num_layers 6 \
  --local_num_layers "$LOCAL_NUM" \
  --max_relative_positions 16 \
  --norm_type "batch" \
  --activation_type "softsign" \
  --reg_loss_weight 1.0 \
  --tran_loss_weight 1.0 \
  --label_smoothing 0.0 \
  --optimizer "adam" \
  --slr_learning_rate 0.001 \
  --slt_learning_rate 0.001 \
  --weight_decay 0.0001 \
  --decrease_factor 0.5 \
  --patience 3 \
  --reg_beam_size 5 \
  --max_output_length 500 \
  --text_beam_size 5 \
  --alpha 2 \
  --batch_size 5 \
  --check_point "$avg_ckpt" \
  --max_epoch 100 \
  --print_step 100