alibibi_positional_encoder.py
import math

import torch
from torch import Tensor
import matplotlib.pyplot as plt


class ALiBiBiEncoder:
    @staticmethod
    def get_slopes(n_heads: int) -> Tensor:
        """Return the per-head ALiBi slopes for ``n_heads`` attention heads."""
        # Largest power of two that does not exceed n_heads.
        n = 2 ** math.floor(math.log2(n_heads))
        # Base slope 2^(-8/n); the slopes form a geometric sequence.
        m_0 = 2.0 ** (-8.0 / n)
        # Interpolate the exponents so that exactly n_heads slopes are produced,
        # even when n_heads is not a power of two.
        m = torch.pow(m_0, torch.arange(1, 1 + n, step=n / n_heads))
        return m
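
    # For illustration: with n_heads=8, n = 8 and m_0 = 2**-1, so the slopes are
    # the geometric sequence [2**-1, 2**-2, ..., 2**-8]
    # = [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625].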

    @staticmethod
    def generate_reverse_distance_matrix(dimension: int, left_negative: bool = True) -> Tensor:
        """Build a matrix where closer positions get larger values.

        The diagonal holds ``dimension`` and values decay with distance;
        entries below the diagonal are negated when ``left_negative`` is True.
        """
        matrix = torch.zeros((dimension, dimension), dtype=torch.float)
        multiplier = -1 if left_negative else 1
        for i in range(dimension):
            for j in range(dimension):
                if i == j:
                    matrix[i, j] = dimension
                elif i < j:
                    matrix[i, j] = dimension - abs(i - j)
                elif i > j:
                    matrix[i, j] = (dimension - abs(i - j)) * multiplier
        return matrix

    @staticmethod
    def generate_distance_matrix(dim: int, with_mask: bool = False) -> Tensor:
        """Build a ``dim x dim`` matrix of absolute position distances.

        With ``with_mask=True`` the upper triangle (future positions) is zeroed,
        matching causal attention.
        """
        matrix = torch.zeros(dim, dim)
        for i in range(dim):
            for j in range(dim):
                if i == j:
                    matrix[i, j] = 0
                elif i < j:
                    matrix[i, j] = 0 if with_mask else (j - i)
                else:
                    matrix[i, j] = i - j
        return matrix
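
    # For illustration, generate_distance_matrix(5) produces the symmetric
    # |i - j| matrix on the left; with_mask=True zeroes the upper triangle:
    #
    #   [[0, 1, 2, 3, 4],          [[0, 0, 0, 0, 0],
    #    [1, 0, 1, 2, 3],           [1, 0, 0, 0, 0],
    #    [2, 1, 0, 1, 2],           [2, 1, 0, 0, 0],
    #    [3, 2, 1, 0, 1],           [3, 2, 1, 0, 0],
    #    [4, 3, 2, 1, 0]]           [4, 3, 2, 1, 0]]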

    @staticmethod
    @torch.no_grad()
    def get_alibi_biases(
        batch_size: int,
        n_heads: int,
        sequence_length: int,
        with_mask=False,
        device="cpu",
    ) -> Tensor:
        """Build the ALiBi bias tensor, which is subtracted from the raw attention scores.

        Positional information is not needed in cross-attention
        (see https://github.com/ofirpress/attention_with_linear_biases/issues/5).

        :param batch_size: Batch size used for training.
        :param n_heads: Number of heads in multi-head attention.
        :param sequence_length: Maximum supported sequence length.
        :param with_mask: Whether to hide future positions with a causal mask.
        :param device: Device type.
        :return: Bias tensor of shape (batch_size, n_heads, sequence_length, sequence_length).
        """
        m = ALiBiBiEncoder.get_slopes(n_heads).to(device)
        distance_matrix = ALiBiBiEncoder.generate_distance_matrix(
            sequence_length, with_mask
        ).to(device)
        # Multiply them pair-wise to get the per-head ALiBi bias matrices.
        bias = distance_matrix[None, :, :] * m[:, None, None]
        # Broadcast along a new leading dimension to match the batch size.
        bias_expanded_to_batch_size = bias.unsqueeze(0).expand(batch_size, -1, -1, -1)
        return bias_expanded_to_batch_size
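

def _apply_alibi_bias_sketch(scores: Tensor, bias: Tensor) -> Tensor:
    """Minimal sketch (not part of the original module) of how the returned bias
    could be combined with raw attention scores before softmax, assuming both
    ``scores`` and ``bias`` have shape (batch, n_heads, seq_len, seq_len).

    As the docstring above notes, the bias is subtracted from the real attention
    score, so larger distances push scores further down.
    """
    return torch.softmax(scores - bias, dim=-1)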


def _test_alibi():
    """Smoke test: print the slopes, distance matrix, and bias, then plot each head's bias."""
    slopes = ALiBiBiEncoder.get_slopes(8)
    print(f"slopes list for MHA {slopes}")

    distance_matrix = ALiBiBiEncoder.generate_distance_matrix(5)
    print(f"distance matrix shape {distance_matrix.shape} \n {distance_matrix}")

    bias = ALiBiBiEncoder.get_alibi_biases(
        batch_size=1, n_heads=8, sequence_length=5, with_mask=False
    )
    print(f"bias shape {bias.shape} \n {bias}")

    # Convert the tensor to a NumPy array of shape (n_heads, seq_len, seq_len).
    array_data = bias[0].numpy()

    # Plot one heatmap per attention head on a 2x4 grid of subplots.
    fig, axes = plt.subplots(2, 4, figsize=(24, 10))
    for i, ax in enumerate(axes.flat):
        im = ax.imshow(array_data[i], cmap="viridis", interpolation="nearest")
        ax.set_title(f"Attention head {i + 1}")
        fig.colorbar(im, ax=ax)

    # Adjust layout to prevent overlap.
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    _test_alibi()