-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_augmenters.py
69 lines (53 loc) · 1.94 KB
/
data_augmenters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import random
# Number of candidate layers for mixup.
# - Manifold mixup samples its target layer from range(MIX_UP_PROB_LENGTH).
# - Sentence mixup always targets the last one, MIX_UP_PROB_LENGTH - 1.
MIX_UP_PROB_LENGTH = 3

# Mixup variant name -> identifier string.
# (Not referenced by the functions below; presumably consumed by callers —
# confirm.)
MIX_UP_TYPES = dict(
    sentence='sentence_mixup',
    word='word_mixup',
    manifold='manifold_mixup',
)
def get_manifold_mix_up_dict_and_label(input_id, label):
    """Build a manifold-mixup config for one batch and mix the labels.

    The function samples three things:
    - a mixing coefficient ``lam ~ Beta(2, 2)``;
    - a random permutation of the first dimension of ``input_id``;
    - a target layer drawn uniformly from ``range(MIX_UP_PROB_LENGTH)``.

    The returned dict is the config that ``mix_up`` reads.

    Args:
        input_id: Tensor whose first (batch) dimension, ``input_id.size(0)``,
            sets the permutation length. Only its size is used. (Assumed to
            be a torch tensor.)
        label: Torch tensor of labels indexed along the same first dimension.

    Returns:
        ``(mix_up_dict, mixed_label)``.

        ``mix_up_dict`` has these keys:
        - ``'layer_num'`` (int);
        - ``'lam'`` (float);
        - ``'shuffled_indices'`` (list[int]).

        ``mixed_label`` is
        ``(lam * label + (1 - lam) * label[shuffled_indices]).long()``.
    """
    lam = float(np.random.beta(2.0, 2.0))
    indices = list(range(input_id.size(0)))
    random.shuffle(indices)
    # Drawn after the shuffle; uniform over 0 .. MIX_UP_PROB_LENGTH - 1.
    layer_num = random.randrange(MIX_UP_PROB_LENGTH)
    mix_up_dict = {
        # Must be a plain int (no trailing comma / tuple): ``mix_up`` compares
        # it with ``==`` against the current layer index.
        'layer_num': layer_num,
        'lam': lam,
        'shuffled_indices': indices,
    }
    shuffled_label = label[indices]
    # NOTE(review): ``.long()`` truncates the interpolated value toward zero.
    # - For soft/one-hot float labels, this discards the mix.
    # - For integer class ids, interpolating two different ids is generally
    #   not a meaningful class.
    # Kept because callers may rely on an integer label tensor. Confirm the
    # intended label format with the training loop.
    label = (lam * label + (1 - lam) * shuffled_label).long()
    return mix_up_dict, label
def get_word_mix_up_dict_and_label(input_id, label):
    """Build a mixup config targeting layer 0 for one batch and mix the labels.

    The function samples a mixing coefficient ``lam ~ Beta(2, 2)`` and a
    random permutation of the first dimension of ``input_id``. The target
    layer is fixed at 0, presumably the word/embedding level; confirm against
    the model. The returned dict is the config that ``mix_up`` reads.

    Args:
        input_id: Tensor whose first (batch) dimension, ``input_id.size(0)``,
            sets the permutation length. Only its size is used. (Assumed to
            be a torch tensor.)
        label: Torch tensor of labels indexed along the same first dimension.

    Returns:
        ``(mix_up_dict, mixed_label)``.

        ``mix_up_dict`` has these keys:
        - ``'layer_num'`` (int, always 0);
        - ``'lam'`` (float);
        - ``'shuffled_indices'`` (list[int]).

        ``mixed_label`` is
        ``(lam * label + (1 - lam) * label[shuffled_indices]).long()``.
    """
    lam = float(np.random.beta(2.0, 2.0))
    indices = list(range(input_id.size(0)))
    random.shuffle(indices)
    mix_up_dict = {
        # Must be a plain int (no trailing comma / tuple): ``mix_up`` compares
        # it with ``==`` against the current layer index.
        'layer_num': 0,
        'lam': lam,
        'shuffled_indices': indices,
    }
    shuffled_label = label[indices]
    # NOTE(review): ``.long()`` truncates the interpolated value toward zero.
    # - For soft/one-hot float labels, this discards the mix.
    # - For integer class ids, interpolating two different ids is generally
    #   not a meaningful class.
    # Kept because callers may rely on an integer label tensor. Confirm the
    # intended label format with the training loop.
    label = (lam * label + (1 - lam) * shuffled_label).long()
    return mix_up_dict, label
def get_sentence_mix_up_dict_and_label(input_id, label):
    """Build a mixup config targeting the last candidate layer and mix labels.

    The function samples a mixing coefficient ``lam ~ Beta(2, 2)`` and a
    random permutation of the first dimension of ``input_id``. The target
    layer is fixed at ``MIX_UP_PROB_LENGTH - 1``. The returned dict is the
    config that ``mix_up`` reads.

    Args:
        input_id: Tensor whose first (batch) dimension, ``input_id.size(0)``,
            sets the permutation length. Only its size is used. (Assumed to
            be a torch tensor.)
        label: Torch tensor of labels indexed along the same first dimension.

    Returns:
        ``(mix_up_dict, mixed_label)``.

        ``mix_up_dict`` has these keys:
        - ``'layer_num'`` (int, ``MIX_UP_PROB_LENGTH - 1``);
        - ``'lam'`` (float);
        - ``'shuffled_indices'`` (list[int]).

        ``mixed_label`` is
        ``(lam * label + (1 - lam) * label[shuffled_indices]).long()``.
    """
    lam = float(np.random.beta(2.0, 2.0))
    indices = list(range(input_id.size(0)))
    random.shuffle(indices)
    mix_up_dict = {
        # Must be a plain int (no trailing comma / tuple): ``mix_up`` compares
        # it with ``==`` against the current layer index.
        'layer_num': MIX_UP_PROB_LENGTH - 1,
        'lam': lam,
        'shuffled_indices': indices,
    }
    shuffled_label = label[indices]
    # NOTE(review): ``.long()`` truncates the interpolated value toward zero.
    # - For soft/one-hot float labels, this discards the mix.
    # - For integer class ids, interpolating two different ids is generally
    #   not a meaningful class.
    # Kept because callers may rely on an integer label tensor. Confirm the
    # intended label format with the training loop.
    label = (lam * label + (1 - lam) * shuffled_label).long()
    return mix_up_dict, label
def mix_up(x, current_layer, mix_up_dict, training):
    """Blend ``x`` with a batch-permuted copy of itself at the chosen layer.

    Mixing happens only when ``training`` is truthy and ``current_layer``
    equals ``mix_up_dict['layer_num']``. In every other case ``x`` is
    returned unchanged (the same object). The comparison is plain ``==``, so
    ``layer_num`` must be the same kind of value as ``current_layer``; an int
    never equals a 1-tuple.

    Args:
        x: Array/tensor that can be indexed along its first axis by a list
            of ints.
        current_layer: Index of the layer currently being computed.
        mix_up_dict: Dict with keys ``'layer_num'``, ``'lam'`` and
            ``'shuffled_indices'``. It is not read when ``training`` is falsy.
        training: Truthy to enable mixing.

    Returns:
        ``lam * x + (1 - lam) * x[shuffled_indices]`` when mixing applies,
        otherwise ``x``.
    """
    if training and current_layer == mix_up_dict['layer_num']:
        order = mix_up_dict['shuffled_indices']
        weight = mix_up_dict['lam']
        partner = x[order]
        return weight * x + (1 - weight) * partner
    return x