
Commit b819364

Authored by vvchernov (Valery Chernov)
[Frontend] [Torch] [ONNX] GRU layer (#8781)
* GRU cell was implemented in common.py; GRU is now supported on the PyTorch frontend side
* Updated GRU in common.py and the ONNX frontend
* Fixed an issue related to GRU accuracy in the PyTorch and ONNX frontends
* Small fixes; removed excess code
* The common GRU was additionally updated; the tuned PyTorch GRU was strongly accelerated
* The GRU cell in the ONNX frontend now uses the implementation from common.py; the previous implementation was removed
* Small fixes in comments
* Fixes after review; a GRU test was implemented for the PyTorch frontend
* Tests for RNN layers were unified for the PyTorch frontend

Co-authored-by: Valery Chernov <valery.chernov@deelvin.com>
1 parent 02b57a6 commit b819364
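Since the commit message says the shared GRU cell now backs both the PyTorch and ONNX frontends, a minimal conversion sketch may help place the diffs below. Everything in it (model configuration, input name, sizes) is an illustrative assumption, not part of this commit:

# Hypothetical end-to-end sketch: trace a torch.nn.GRU and import it through
# the Relay PyTorch frontend, which (per the commit message) now lowers GRU
# via the shared gru_cell helper added in common.py.
import torch
from tvm import relay

seq_len, batch, input_size, hidden_size = 5, 2, 8, 16  # made-up sizes
model = torch.nn.GRU(input_size, hidden_size, num_layers=1, bidirectional=True).eval()
dummy_input = torch.randn(seq_len, batch, input_size)

scripted = torch.jit.trace(model, dummy_input)  # TorchScript graph containing aten::gru
mod, params = relay.frontend.from_pytorch(
    scripted, [("input", (seq_len, batch, input_size))]
)
print(mod)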

File tree

5 files changed: +774, −441 lines


python/tvm/relay/frontend/common.py

Lines changed: 84 additions & 0 deletions
@@ -658,6 +658,90 @@ def unbind(data, axis=0):
         return _expr.TupleWrapper(_expr.Tuple(ret), selections)
 
 
+def gru_cell(
+    input_seqs,
+    hidden_state,
+    w_inp,
+    w_hid,
+    b_inp=None,
+    b_hid=None,
+    rz_act=_op.sigmoid,
+    n_act=_op.tanh,
+    backwards=False,
+    linear_before_reset=True,
+):
+    """
+    Common implementation of GRU cell for all frontends of TVM
+    TODO(vvchernov): currently it is used by pytorch and ONNX. Extend for other frontends
+
+    Parameters
+    ----------
+    input_seqs : List[relay.Expr]
+        The sequence of input tensors
+        Input tensor should be 2d while issue #8412 is not resolved
+        Shape = (batch, feature_size)
+    hidden_state : relay.Expr
+        Hidden state. shape = (batch_size, hidden_size)
+    w_inp, w_hid : relay.Expr
+        weight matrices. wi shape = (3 * hidden_size, feature_size),
+        wh shape = (3 * hidden_size, hidden_size)
+        NOTE: wi = (w_ir|w_iz|w_in) for reset, update and new gates.
+        The order is important for correct GRU calculation!
+    b_inp, b_hid : relay.Expr
+        bias matrices. The same order of internal parts as for weights. shape = (3 * hidden_size)
+    rz_act : relay.op
+        activation function for the reset and update gates. sigmoid by default
+    n_act : relay.op
+        activation function for the new gate. tanh by default
+    backwards : bool
+        Flag for reverse pass of GRU
+
+    Returns
+    -------
+    result : List[relay.Expr], relay.Expr
+        The sequence of computed results and the final hidden state
+    """
+
+    outputs_list = []
+    for x_t in input_seqs if not backwards else reversed(input_seqs):
+        xwt = _op.nn.dense(x_t, w_inp)
+        if linear_before_reset:
+            hwt = _op.nn.dense(hidden_state, w_hid)
+            if b_inp is not None and b_hid is not None:
+                xwt += b_inp
+                hwt += b_hid
+            i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
+            h_r, h_z, h_n = _op.split(hwt, 3, axis=-1)
+            r_gate = rz_act(i_r + h_r)
+            z_gate = rz_act(i_z + h_z)
+            n_gate = n_act(i_n + r_gate * h_n)
+        else:
+            i_r, i_z, i_n = _op.split(xwt, 3, axis=1)
+            w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
+            r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
+            z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
+            if b_inp is not None and b_hid is not None:
+                b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
+                b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
+                r_gate += b_ir + b_hr
+                z_gate += b_iz + b_hz
+                i_n += b_in
+                h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn
+            else:
+                h_n = _op.nn.dense((r_gate * hidden_state), w_hn)
+            r_gate = rz_act(r_gate)
+            z_gate = rz_act(z_gate)
+            n_gate = n_act(i_n + h_n)
+
+        hidden_state = (hidden_state - n_gate) * z_gate + n_gate
+
+        outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
+
+    return outputs_list, hidden_state
+
+
 def lstm_cell(
     input_seqs,
     hidden_state,
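Before the ONNX changes, a hypothetical direct use of the new helper may make its calling convention clearer; the sizes and variable names below are illustrative and do not appear in the diff:

# Sketch only: build a tiny Relay graph around gru_cell (made-up sizes).
from tvm import relay
from tvm.relay import op as _op
from tvm.relay.frontend.common import gru_cell, unbind

seq_len, batch, feature_size, hidden_size = 5, 2, 8, 16

x = relay.var("x", shape=(seq_len, batch, feature_size), dtype="float32")
h0 = relay.var("h0", shape=(batch, hidden_size), dtype="float32")
w_inp = relay.var("w_inp", shape=(3 * hidden_size, feature_size), dtype="float32")
w_hid = relay.var("w_hid", shape=(3 * hidden_size, hidden_size), dtype="float32")

# gru_cell consumes a list of 2-D (batch, feature_size) steps, hence unbind
x_steps = unbind(x, axis=0)
outputs, h_final = gru_cell(x_steps, h0, w_inp, w_hid)

# Re-stack the per-step outputs into (seq_len, batch, hidden_size)
result = _op.stack(outputs, axis=0)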
python/tvm/relay/frontend/onnx.py

Lines changed: 72 additions & 77 deletions
@@ -47,6 +47,7 @@
     infer_value,
     new_var,
     unbind,
+    gru_cell,
     lstm_cell,
 )

@@ -2349,56 +2350,41 @@ class GRU(RNN):
     """Operator convert for GRU"""
 
     @classmethod
-    def generate_gru(
-        cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False
+    def bidir_gru_cell(
+        cls,
+        input_seqs,
+        weight_dicts,
+        acts,
     ):
-        """Create an unrolled gru loop.
-
-        See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
         """
-        h_list = []
-        seq_length = len(X_steps)
-        for i in range(seq_length):
-            step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
-            step = _op.squeeze(step, axis=[0])
-            current = _op.nn.dense(step, W)
-            cz, cr, ch = _op.split(current, 3, axis=1)
-            rz, rr, rh = _op.split(R, 3, axis=0)
-            z = cz + _op.nn.dense(H_t, rz)
-            r = cr + _op.nn.dense(H_t, rr)
-            if B is not None:
-                WB, RB = _op.split(B, 2)
-                wbz, wbr, wbh = _op.split(WB, 3, axis=-1)
-                rbz, rbr, rbh = _op.split(RB, 3, axis=-1)
-                z += wbz + rbz
-                r += wbr + rbr
-                if linear_before_reset:
-                    h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh
-                else:
-                    h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh
-            else:
-                if linear_before_reset:
-                    h = ch + (r * (_op.nn.dense(H_t, rh)))
-                else:
-                    h = ch + _op.nn.dense((r * H_t), rh)
-
-            z = f_act(z)
-            r = f_act(r)
-            h = g_act(h)
-
-            H_t = ((_expr.const(1, dtype=W_dtype) - z) * h) + (z * H_t)
-            h_list.append(_op.expand_dims(H_t, axis=0))
+        Bidirectional GRU cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t = gru_cell(
+            input_seqs,
+            **weight_dicts[0],
+            rz_act=acts[0],
+            n_act=acts[1],
+        )
 
-        if backwards:
-            # Canonical view is hidden states from the first token not last
-            h_list = h_list[::-1]
+        reverse_outputs, rev_H_t = gru_cell(
+            input_seqs,
+            **weight_dicts[1],
+            rz_act=acts[2],
+            n_act=acts[3],
+            backwards=True,
+        )
 
-        # Concatenate outputs and add back in direction axis.
-        concatenated = _op.concatenate(h_list, 0)
-        output = _op.expand_dims(concatenated, axis=1)
-        H_t = _op.expand_dims(H_t, axis=0)
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
+            )
 
-        return output, H_t
+        return (
+            _op.stack(final_outputs, axis=0),
+            _op.stack([fw_H_t, rev_H_t], axis=0),
+        )
 
     @classmethod
     def _impl_v7(cls, inputs, attr, params):
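The forward/reverse pairing above (forward_outputs[i] with reverse_outputs[seq_len - 1 - i]) can be sanity-checked with a plain-Python analogue; this snippet is purely illustrative and not part of the commit:

# gru_cell(..., backwards=True) visits the sequence from the end but appends
# outputs in visiting order, so its k-th entry belongs to time step
# seq_len - 1 - k; indexing with seq_len - 1 - i realigns it to step i.
seq_len = 4
forward_outputs = [f"fw_t{t}" for t in range(seq_len)]            # steps 0..3
reverse_outputs = [f"bw_t{t}" for t in reversed(range(seq_len))]  # visited 3..0

paired = [(forward_outputs[i], reverse_outputs[seq_len - 1 - i]) for i in range(seq_len)]
assert paired[0] == ("fw_t0", "bw_t0") and paired[-1] == ("fw_t3", "bw_t3")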
@@ -2416,20 +2402,14 @@ def _impl_v7(cls, inputs, attr, params):
         W_dtype = infer_type(Wp).checked_type.dtype
 
         if num_directions not in [1, 2]:
-            raise NotImplementedError(
-                f"Directions for GRUs should be either 1 or 2 got {num_directions}"
-            )
+            raise ValueError("num_directions must be either 1 or 2!")
 
         X_shape = infer_shape(X)
         hidden_size = infer_shape(Rp)[-1]
         batch_size = X_shape[1]
 
-        # Initialize state if not provided.
-        # Otherwise remove bidirectional axis.
         if Hp_0 is None:
             Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
-        if Bp is None:
-            Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype)
 
         if "activations" in attr:
             activations = attr["activations"]
@@ -2460,39 +2440,54 @@ def _impl_v7(cls, inputs, attr, params):
         else:
             acts = [_op.sigmoid, _op.tanh] * 2
 
-        result_output = []
-        result_H = []
+        # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
+        X_steps = unbind(X, axis=0)
 
-        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
         H_ts = _op.split(Hp_0, num_directions)
         Ws = _op.split(Wp, num_directions)
         Rs = _op.split(Rp, num_directions)
-        Bs = _op.split(Bp, num_directions)
 
+        if Bp is not None:
+            Bs = _op.split(Bp, num_directions)
+
+        weights_dicts = []
         for i in range(num_directions):
-            H_t = _op.squeeze(H_ts[i], axis=[0])
-            W = _op.squeeze(Ws[i], axis=[0])
-            R = _op.squeeze(Rs[i], axis=[0])
-            B = _op.squeeze(Bs[i], axis=[0])
-            f_act, g_act = acts[i * 2 : (i + 1) * 2]
-            output, H = GRU.generate_gru(
-                X_steps=X_steps,
-                H_t=H_t,
-                W=W,
-                R=R,
-                B=B,
-                linear_before_reset=linear_before_reset,
-                f_act=f_act,
-                g_act=g_act,
-                W_dtype=W_dtype,
-                backwards=i == 1,
-            )
+            weights_dict = {}
+
+            weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
+            weights_dict["linear_before_reset"] = linear_before_reset
+
+            # Weights permutation: onnx packs the gates as z|r|n, gru_cell expects r|z|n
+            matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3)
+            weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0)
+            matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3)
+            weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0)
+            if Bp is not None:
+                Bi, Bh = _op.split(Bs[i], 2, -1)
+                matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3)
+                weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0)
+                matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3)
+                weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], axis=0)
+            weights_dicts.append(weights_dict)
 
-            result_output.append(output)
-            result_H.append(H)
+        if num_directions == 2:
+            output, H = GRU.bidir_gru_cell(
+                input_seqs=X_steps,
+                weight_dicts=weights_dicts,
+                acts=acts,
+            )
+        else:
+            # outputs shape = [seqs_num, (batch_size, hidden_size)]
+            outputs, H = gru_cell(
+                input_seqs=X_steps,
+                **weights_dicts[0],
+                rz_act=acts[0],
+                n_act=acts[1],
+            )
 
-        output = _op.concatenate(result_output, axis=1)
-        H = _op.concatenate(result_H, axis=0)
+            # output shape = (seqs_num, num_directions, batch_size, hidden_size)
+            output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
+            H = _op.expand_dims(H, axis=0)
 
         return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)
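For reference, the per-direction gate re-packing in _impl_v7 (ONNX packs the GRU gates as z|r|n, while gru_cell expects r|z|n) can be illustrated with a small numpy sketch; the sizes are made up:

import numpy as np

# Illustration only: mirror the split/concatenate re-packing done in _impl_v7.
hidden_size, feature_size = 4, 3
W_onnx = np.random.rand(3 * hidden_size, feature_size).astype("float32")

Wz, Wr, Wn = np.split(W_onnx, 3, axis=0)       # ONNX order: update, reset, new
W_cell = np.concatenate([Wr, Wz, Wn], axis=0)  # gru_cell order: reset, update, new

assert W_cell.shape == W_onnx.shape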