-
Notifications
You must be signed in to change notification settings - Fork 1
/
transformer_configuration.yaml
112 lines (90 loc) · 2.61 KB
/
transformer_configuration.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# general
################################################################################
# experiment files directory
save_directory: experiments
# experiment naming prefix
experiment_prefix: transformer
# PyTorch random number generator initialization seed
random_seed: 5
#random_seed: 7
#random_seed: 11
################################################################################
# dataset
################################################################################
dataset_id: 1pct
#dataset_id: 5pct
#dataset_id: 20pct
#dataset_id: full
# training, validation, test split
test_ratio: 0.15
validation_ratio: 0.15
################################################################################
# features
################################################################################
sequence_length: 1000 # short length for fast dev runs
#sequence_length: 7549 # dataset sequences median length
#sequence_length: 29001 # dataset sequences mean length
#sequence_length: 115019 # 29001 + 86018 : mean + 1 * standard_deviation
#padding_side: left
padding_side: right
################################################################################
# network architecture
################################################################################
embedding_dimension: 16
#embedding_dimension: 32
#embedding_dimension: 64
num_heads: 2
#num_heads: 4
#num_heads: 8
transformer_depth: 1
#transformer_depth: 2
#transformer_depth: 3
activation_function: relu
#activation_function: gelu
feedforward_connections: 64
#feedforward_connections: 128
# L2 regularization
weight_decay: 0
#weight_decay: 1.0e-6
#weight_decay: 1.0e-5
# max norm for gradient clipping
clip_max_norm: 0
#clip_max_norm: 5
dropout_probability: 0
#dropout_probability: 0.1
#dropout_probability: 0.2
################################################################################
# training
################################################################################
batch_size: 4
#batch_size: 8
#batch_size: 16
#batch_size: 32
#batch_size: 64
#num_workers: 0
num_workers: 1
#num_workers: 3
#num_workers: 5
#num_workers: 13
learning_rate: 3.0e-4
#learning_rate: 1.0e-4
#learning_rate: 3.0e-5
# number of epochs without validation loss improvement before training stops
patience: 3
#patience: 5
#patience: 7
# minimum validation loss change to consider as improvement
loss_delta: 0
#loss_delta: 1.0e-6
# maximum number of training epochs
#max_epochs: 1
max_epochs: 3
#max_epochs: 10
#max_epochs: 100
#max_epochs: 1000
gpus:
#gpus: 1
profiler:
#profiler: simple
#profiler: pytorch
################################################################################