[Frontend] [Torch] [ONNX] GRU layer #8781

Merged
merged 9 commits on Aug 25, 2021
GRU cell in ONNX frontend is now used from common.py; the previous implementation was removed
vvchernov committed Aug 24, 2021
commit 677eafbec7b810129697e014509fdf2ceca4de5f
150 changes: 72 additions & 78 deletions in python/tvm/relay/frontend/onnx.py
@@ -47,6 +47,7 @@
     infer_value,
     new_var,
     unbind,
+    gru_cell,
     lstm_cell,
 )

@@ -2349,57 +2350,41 @@ class GRU(RNN):
     """Operator convert for GRU"""

     @classmethod
-    def generate_gru(
-        cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False
-    ):
-        """Create an unrolled gru loop.
-
-        See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
-        """
-        h_list = []
-        seq_length = len(X_steps)
-        for i in range(seq_length):
-            step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
-            step = _op.squeeze(step, axis=[0])
-            current = _op.nn.dense(step, W)
-            cz, cr, ch = _op.split(current, 3, axis=1)
-            rz, rr, rh = _op.split(R, 3, axis=0)
-            z = cz + _op.nn.dense(H_t, rz)
-            r = cr + _op.nn.dense(H_t, rr)
-            if B is not None:
-                WB, RB = _op.split(B, 2)
-                wbz, wbr, wbh = _op.split(WB, 3, axis=-1)
-                rbz, rbr, rbh = _op.split(RB, 3, axis=-1)
-                z += wbz + rbz
-                r += wbr + rbr
-                r = f_act(r)
-                if linear_before_reset:
-                    h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh
-                else:
-                    h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh
-            else:
-                r = f_act(r)
-                if linear_before_reset:
-                    h = ch + (r * (_op.nn.dense(H_t, rh)))
-                else:
-                    h = ch + _op.nn.dense((r * H_t), rh)
-
-            z = f_act(z)
-            h = g_act(h)
-
-            H_t = (H_t - h) * z + h
-            h_list.append(_op.expand_dims(H_t, axis=0))
-
-        if backwards:
-            # Canonical view is hidden states from the first token not last
-            h_list = h_list[::-1]
-
-        # Concatenate outputs and add back in direction axis.
-        concatenated = _op.concatenate(h_list, 0)
-        output = _op.expand_dims(concatenated, axis=1)
-        H_t = _op.expand_dims(H_t, axis=0)
-
-        return output, H_t
+    def bidir_gru_cell(
+        cls,
+        input_seqs,
+        weight_dicts,
+        acts,
+    ):
+        """
+        Bidirectional GRU cell
+        """
+        seq_len = len(input_seqs)
+        forward_outputs, fw_H_t = gru_cell(
+            input_seqs,
+            **weight_dicts[0],
+            rz_act=acts[0],
+            n_act=acts[1],
+        )
+
+        reverse_outputs, rev_H_t = gru_cell(
+            input_seqs,
+            **weight_dicts[1],
+            rz_act=acts[2],
+            n_act=acts[3],
+            backwards=True,
+        )
+
+        final_outputs = []
+        for i in range(seq_len):
+            final_outputs.append(
+                _op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
+            )
+
+        return (
+            _op.stack(final_outputs, axis=0),
+            _op.stack([fw_H_t, rev_H_t], axis=0),
+        )
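Note that the removed loop's update, H_t = (H_t - h) * z + h, expands to z * H_t + (1 - z) * h, which is exactly the spec's Ht = (1 - zt) (.) ht + zt (.) Ht-1, so delegating to the shared gru_cell does not change the math. The index arithmetic in bidir_gru_cell can be illustrated with a small NumPy sketch; fwd and rev below are made-up stand-ins for the per-step outputs of the forward and backward cells (the backward cell emits them in reverse token order), not TVM tensors:

    import numpy as np

    seq_len, batch_size, hidden_size = 4, 2, 3
    # fwd[t] is the forward cell's output for token t; rev[j] is the backward
    # cell's output at processing step j, i.e. for token seq_len - 1 - j.
    fwd = [np.full((batch_size, hidden_size), t, dtype="float32") for t in range(seq_len)]
    rev = [np.full((batch_size, hidden_size), seq_len - 1 - j, dtype="float32") for j in range(seq_len)]

    # Pairing fwd[i] with rev[seq_len - 1 - i] realigns both directions on token i.
    final = np.stack(
        [np.stack([fwd[i], rev[seq_len - 1 - i]], axis=0) for i in range(seq_len)],
        axis=0,
    )
    assert final.shape == (seq_len, 2, batch_size, hidden_size)  # (seq, dirs, batch, hidden)
    assert (final[:, 0] == final[:, 1]).all()  # both directions now indexed by the same token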

     @classmethod
     def _impl_v7(cls, inputs, attr, params):
@@ -2417,20 +2402,14 @@ def _impl_v7(cls, inputs, attr, params):
         W_dtype = infer_type(Wp).checked_type.dtype

         if num_directions not in [1, 2]:
-            raise NotImplementedError(
-                f"Directions for GRUs should be either 1 or 2 got {num_directions}"
-            )
+            raise ValueError("num_directions must be either 1 or 2!")

         X_shape = infer_shape(X)
         hidden_size = infer_shape(Rp)[-1]
         batch_size = X_shape[1]

-        # Initialize state if not provided.
-        # Otherwise remove bidirectional axis.
         if Hp_0 is None:
             Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
-        if Bp is None:
-            Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype)

         if "activations" in attr:
             activations = attr["activations"]
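For context, the operand shapes from the ONNX GRU spec (linked in the removed docstring above) explain the defaults in this hunk; the concrete sizes below are made-up examples:

    # ONNX GRU operand shapes, with made-up example sizes:
    seq_len, batch_size, input_size, hidden_size, num_directions = 7, 2, 5, 4, 2

    X_shape = (seq_len, batch_size, input_size)               # input sequence
    W_shape = (num_directions, 3 * hidden_size, input_size)   # z, r, h gate weights stacked
    R_shape = (num_directions, 3 * hidden_size, hidden_size)  # recurrence weights
    B_shape = (num_directions, 6 * hidden_size)               # Wb and Rb concatenated
    H0_shape = (num_directions, batch_size, hidden_size)      # optional initial hidden state

    # Hence hidden_size = infer_shape(Rp)[-1], batch_size = X_shape[1], and a
    # missing initial state defaults to zeros of H0_shape, as the converter does.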
@@ -2461,39 +2440,54 @@ def _impl_v7(cls, inputs, attr, params):
         else:
             acts = [_op.sigmoid, _op.tanh] * 2

-        result_output = []
-        result_H = []
-
-        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
+        # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
+        X_steps = unbind(X, axis=0)
+
         H_ts = _op.split(Hp_0, num_directions)
         Ws = _op.split(Wp, num_directions)
         Rs = _op.split(Rp, num_directions)
-        Bs = _op.split(Bp, num_directions)

+        if Bp is not None:
+            Bs = _op.split(Bp, num_directions)
+
+        weights_dicts = []
         for i in range(num_directions):
-            H_t = _op.squeeze(H_ts[i], axis=[0])
-            W = _op.squeeze(Ws[i], axis=[0])
-            R = _op.squeeze(Rs[i], axis=[0])
-            B = _op.squeeze(Bs[i], axis=[0])
-            f_act, g_act = acts[i * 2 : (i + 1) * 2]
-            output, H = GRU.generate_gru(
-                X_steps=X_steps,
-                H_t=H_t,
-                W=W,
-                R=R,
-                B=B,
-                linear_before_reset=linear_before_reset,
-                f_act=f_act,
-                g_act=g_act,
-                W_dtype=W_dtype,
-                backwards=i == 1,
-            )
-
-            result_output.append(output)
-            result_H.append(H)
+            weights_dict = {}
+
+            weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
+            weights_dict["linear_before_reset"] = linear_before_reset
+
+            # Weights permutation: ONNX gate order is z-r-n, gru_cell expects r-z-n
+            matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3)
+            weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0)
+            matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3)
+            weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0)
+            if Bp is not None:
+                Bi, Bh = _op.split(Bs[i], 2, -1)
+                matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3)
+                weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0)
+                matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3)
+                weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], axis=0)
+            weights_dicts.append(weights_dict)

-        output = _op.concatenate(result_output, axis=1)
-        H = _op.concatenate(result_H, axis=0)
+        if num_directions == 2:
+            output, H = GRU.bidir_gru_cell(
+                input_seqs=X_steps,
+                weight_dicts=weights_dicts,
+                acts=acts,
+            )
+        else:
+            # outputs shape = [seqs_num, (batch_size, hidden_size)]
+            outputs, H = gru_cell(
+                input_seqs=X_steps,
+                **weights_dicts[0],
+                rz_act=acts[0],
+                n_act=acts[1],
+            )
+
+            # output shape = (seqs_num, num_directions, batch_size, hidden_size)
+            output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
+            H = _op.expand_dims(H, axis=0)

         return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)
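The gate reordering performed in the weights loop above can be checked in isolation. A minimal NumPy sketch, assuming ONNX packs the gate blocks as [z; r; n] along the first axis while gru_cell expects [r; z; n]; reorder_gates is a hypothetical helper, not part of TVM:

    import numpy as np

    def reorder_gates(w_onnx):
        # Hypothetical helper: reorder ONNX gate blocks z-r-n into gru_cell's r-z-n.
        matz, matr, matn = np.split(w_onnx, 3, axis=0)
        return np.concatenate([matr, matz, matn], axis=0)

    hidden_size, input_size = 3, 4
    # Tag each gate block with a constant so the reorder is visible: z=0, r=1, n=2.
    w_onnx = np.concatenate([np.full((hidden_size, input_size), g) for g in (0, 1, 2)])
    w_cell = reorder_gates(w_onnx)
    assert (w_cell[:hidden_size] == 1).all()                  # reset gate block now first
    assert (w_cell[hidden_size:2 * hidden_size] == 0).all()   # update gate block second

The converter applies the same r-z-n reorder to W, to R, and to both halves of B.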
