-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathembeddings.py
95 lines (66 loc) · 2.78 KB
/
embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from abc import ABC, abstractmethod
import numpy as np
class AbstractMergeEmbeddings(ABC):
    """Interface every embedding-merge strategy has to implement.

    Concrete subclasses must provide all four abstract methods below;
    ``repr()`` of an instance is the same text as ``str()``.
    """

    def __repr__(self):
        # Reuse the subclass-provided string form so both renderings agree.
        return self.__str__()

    @abstractmethod
    def __str__(self):
        """Return a meaningful string representation of the instance."""

    @abstractmethod
    def should_duplicate_labels(self):
        """Tell the caller whether labels should be duplicated."""

    @abstractmethod
    def unpack_embeddings(self, df_embeddings):
        """Unpack and process the embeddings held in ``df_embeddings``."""

    @abstractmethod
    def unpack_embeddings_test(self, df_embeddings):
        """Unpack and process the embeddings for testing purposes."""
def create_consecutive_groups(n_samples, n_groups):
    """Assign consecutive samples to ``n_groups`` contiguous groups.

    Each group id ``0 .. n_groups - 1`` is repeated ``n_samples // n_groups``
    times, in order. Samples left over after that even split are all placed
    in the last group (``n_groups - 1``).

    Args:
        n_samples: Number of samples to label.
        n_groups: Number of groups to split the samples into.

    Returns:
        A list with one group id per sample; empty when there are no samples.

    Raises:
        ValueError: If there are samples to label but ``n_groups`` is not
            positive.
    """
    # Nothing to label: avoid dividing by a zero group count for empty input.
    if n_samples == 0:
        return []
    if n_groups <= 0:
        raise ValueError(
            f"n_groups must be positive to group {n_samples} samples, got {n_groups}"
        )

    group_size, remainder = divmod(n_samples, n_groups)
    groups = [group for group in range(n_groups) for _ in range(group_size)]
    # Leftover samples are appended to the final group rather than spread out.
    groups.extend([n_groups - 1] * remainder)
    return groups
class AbstractNumpyMergeEmbeddings(AbstractMergeEmbeddings):
    """Base class for strategies that merge each embedding pair into one vector.

    Subclasses only implement ``merge_embeddings`` (and ``__str__``); this
    class applies it row by row to the ``emb_prot1`` / ``emb_prot2`` columns.
    """

    @abstractmethod
    def merge_embeddings(self, embeddings_1, embeddings_2):
        """Combine two flattened embeddings into a single merged result."""
        pass

    def unpack_embeddings(self, df_embeddings):
        """Merge every embedding pair found in ``df_embeddings``.

        Args:
            df_embeddings: Table indexable by the column names ``'emb_prot1'``
                and ``'emb_prot2'``; each entry must provide ``.flatten()``.

        Returns:
            A ``(features, groups)`` tuple: ``features`` is ``np.array`` of the
            merged rows and ``groups`` is a list holding one group id per row,
            where every row is given its own group.
        """
        rows = [
            self.merge_embeddings(x1.flatten(), x2.flatten())
            for x1, x2 in zip(df_embeddings['emb_prot1'], df_embeddings['emb_prot2'])
        ]
        n_samples = len(rows)
        # An empty table has no groups to build; asking the helper to split
        # zero samples into zero groups is not a meaningful request.
        groups = create_consecutive_groups(n_samples, n_samples) if n_samples else []
        return np.array(rows), groups

    def unpack_embeddings_test(self, df_embeddings):
        """Process test data exactly like ``unpack_embeddings``."""
        return self.unpack_embeddings(df_embeddings)

    def should_duplicate_labels(self):
        """Return ``False``: this strategy never asks for duplicated labels."""
        return False
class AddEmbeddings(AbstractNumpyMergeEmbeddings):
    """Merge strategy that sums the two embeddings element-wise."""

    def merge_embeddings(self, embeddings_1, embeddings_2):
        """Return ``np.add`` of the two embeddings."""
        summed = np.add(embeddings_1, embeddings_2)
        return summed

    def __str__(self):
        """Return the strategy's short name."""
        return 'add'
class MultiplyEmbeddings(AbstractNumpyMergeEmbeddings):
    """Merge strategy that multiplies the two embeddings element-wise."""

    def merge_embeddings(self, embeddings_1, embeddings_2):
        """Return ``np.multiply`` of the two embeddings."""
        product = np.multiply(embeddings_1, embeddings_2)
        return product

    def __str__(self):
        """Return the strategy's short name."""
        return 'multiply'
class DiffEmbeddings(AbstractNumpyMergeEmbeddings):
    """Merge strategy that takes the absolute element-wise difference."""

    def merge_embeddings(self, embeddings_1, embeddings_2):
        """Return ``|embeddings_1 - embeddings_2|`` computed element-wise."""
        delta = np.subtract(embeddings_1, embeddings_2)
        # The absolute value makes the result independent of argument order.
        return np.abs(delta)

    def __str__(self):
        """Return the strategy's short name."""
        return 'diff'
class ConcatEmbeddings(AbstractNumpyMergeEmbeddings):
    """Merge strategy that joins the two embeddings end to end."""

    def merge_embeddings(self, embeddings_1, embeddings_2):
        """Return the two embeddings concatenated, first one leading."""
        pair = (embeddings_1, embeddings_2)
        return np.concatenate(pair)

    def unpack_embeddings(self, df_embeddings):
        """Delegate to the parent implementation and return its two results."""
        merged, group_ids = super().unpack_embeddings(df_embeddings)
        return merged, group_ids

    def unpack_embeddings_test(self, df_embeddings):
        """Return the parent's ``unpack_embeddings`` result for test data."""
        # Goes straight to the parent implementation, not through
        # self.unpack_embeddings.
        return super().unpack_embeddings(df_embeddings)

    def __str__(self):
        """Return the strategy's short name (the trailing underscore is kept)."""
        return 'concat_'