# This is a copy of rnn_attention from MLPerf, with some common sizes hardcoded
# for benchmarking and some control flow stripped out.
# https://github.com/mlperf/training/blob/master/rnn_translator/pytorch/seq2seq/models/attention.py
from . import benchmark
import torch


class BahdanauAttention(benchmark.Benchmark):
def __init__(self, mode, device, dtype, b, t_q, t_k, n):
super().__init__(mode, device, dtype)
self.b = b
self.t_q = t_q
self.t_k = t_k
self.n = n
self.att_query = self.rand(
[b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
)
self.att_keys = self.rand(
[b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
)
self.normalize_bias = self.rand(
[n], device=device, dtype=dtype, requires_grad=self.requires_grad
)
self.linear_att = self.rand(
[n], device=device, dtype=dtype, requires_grad=self.requires_grad
)
self.inputs = [
self.att_query,
self.att_keys,
self.normalize_bias,
self.linear_att,
]

    def forward(self, att_query, att_keys, normalize_bias, linear_att):
"""
Calculate Bahdanau score
:param att_query: b x t_q x n
:param att_keys: b x t_k x n
return b x t_q x t_k scores
"""
b, t_k, n = att_keys.size()
t_q = att_query.size(1)
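        # Expand the query over key positions and the keys over query
        # positions so they can be summed elementwise; both become
        # [b, t_q, t_k, n].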
att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
sum_qk = att_query + att_keys + normalize_bias
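        # tanh over the [b, t_q, t_k, n] sum, then contract the last dim
        # against the [n] attention vector to get [b, t_q, t_k] scores.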
out = torch.tanh(sum_qk).matmul(linear_att)
return out

    def reference(self):
return self.numpy(self.forward(*self.inputs))

    def config(self):
return [self.b, self.t_q, self.t_k, self.n]

    @staticmethod
def module():
return "attention"

    def memory_workload(self):
def memsize(t):
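            # bytes occupied by tensor t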
return t.numel() * t.element_size()

        input_size = (
memsize(self.att_query)
+ memsize(self.att_keys)
+ memsize(self.normalize_bias)
+ memsize(self.linear_att)
)
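        # The hardcoded 4 below assumes 4-byte (float32) elements.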
output_size = 4 * torch.Size([self.b, self.t_q, self.t_k]).numel()
io_size = input_size + output_size
# If matmul is not fused, must write and then read `sum_qk`.
intermediate_size = (
2 * 4 * torch.Size([self.b, self.t_q, self.t_k, self.n]).numel()
)
return {"sol": io_size, "algorithmic": io_size + intermediate_size}

    @staticmethod
def default_configs():
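        # Each config is [b, t_q, t_k, n].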
mlperf_inference = [1280, 1, 66, 1024]
nvidia = [128, 10, 128, 1024]
return [mlperf_inference, nvidia]


benchmark.register_benchmark_class(BahdanauAttention)
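
# Example (an illustrative sketch, not part of the benchmark harness): the
# same score can be computed with plain broadcasting instead of explicit
# expand():
#
#   import torch
#   q = torch.randn(2, 3, 8)          # att_query:      [b, t_q, n]
#   k = torch.randn(2, 5, 8)          # att_keys:       [b, t_k, n]
#   bias = torch.randn(8)             # normalize_bias: [n]
#   v = torch.randn(8)                # linear_att:     [n]
#   scores = torch.tanh(q.unsqueeze(2) + k.unsqueeze(1) + bias).matmul(v)
#   assert scores.shape == (2, 3, 5)  # [b, t_q, t_k]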