forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrnn_eltwise.py
105 lines (87 loc) · 3.15 KB
/
rnn_eltwise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from . import benchmark
import torch
class RNNEltwise(benchmark.Benchmark):
    """Benchmark the element-wise (post-GEMM) portion of an LSTM cell.

    The matrix multiplies are assumed to have happened already: ``input`` and
    ``hx`` arrive as pre-computed gate pre-activations of shape ``[b, 4*hs]``,
    so ``forward`` measures only the sigmoid/tanh/add/mul element-wise work.
    """

    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs
        # Gate pre-activations and biases are [b, 4*hs] (four stacked gates);
        # the cell state is [b, hs]. Creation order matters for RNG parity,
        # and dict insertion order preserves it (Python 3.7+).
        shapes = {
            "input": [b, 4 * hs],
            "hx": [b, 4 * hs],
            "cx": [b, hs],
            "b_ih": [b, 4 * hs],
            "b_hh": [b, 4 * hs],
        }
        self.inputs = []
        for name, shape in shapes.items():
            t = self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            setattr(self, name, t)
            self.inputs.append(t)

    def forward(self, input, hx, cx, b_ih, b_hh):
        """One LSTM-cell element-wise step; returns (hy, cy)."""
        pre_act = input + hx + b_ih + b_hh
        # Split the stacked pre-activations into the four gates along dim 1.
        i_gate, f_gate, g_gate, o_gate = pre_act.chunk(4, 1)
        i_gate = torch.sigmoid(i_gate)
        f_gate = torch.sigmoid(f_gate)
        g_gate = torch.tanh(g_gate)
        o_gate = torch.sigmoid(o_gate)
        cy = f_gate * cx + i_gate * g_gate
        hy = o_gate * torch.tanh(cy)
        return hy, cy

    def config(self):
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        """Bytes moved per iteration: all inputs read, hy and cy written."""

        def nbytes(t):
            return t.numel() * t.element_size()

        read_bytes = sum(nbytes(t) for t in self.inputs)
        # hy and cy each have cx's shape, hence 2x its footprint.
        write_bytes = 2 * nbytes(self.cx)
        total = read_bytes + write_bytes
        return {"sol": total, "algorithmic": total}

    @staticmethod
    def default_configs():
        return [[64, 512]]


benchmark.register_benchmark_class(RNNEltwise)
class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    """Dynamic-shape variant of RNNEltwise.

    Mixes in ``benchmark.DynamicShape`` so each iteration can regenerate the
    input tensors with freshly randomized (b, hs) dimensions.
    """

    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        """Rebuild all input tensors for a newly drawn (b, hs) pair."""
        b, hs = self.rand_shape([self.b, self.hs])
        # Same layout as RNNEltwise.__init__: four stacked gates per
        # pre-activation/bias tensor, cell state at [b, hs]. Keep the
        # creation order identical for RNG parity.
        shapes = [
            ("input", [b, 4 * hs]),
            ("hx", [b, 4 * hs]),
            ("cx", [b, hs]),
            ("b_ih", [b, 4 * hs]),
            ("b_hh", [b, 4 * hs]),
        ]
        self.inputs = []
        for name, shape in shapes:
            t = self.rand(
                shape,
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )
            setattr(self, name, t)
            self.inputs.append(t)

    @staticmethod
    def module():
        return "dynamic_lstm"


benchmark.register_benchmark_class(DynamicLSTM)