small fixes and remove excess
vvchernov committed Aug 18, 2021
1 parent ec402c6 commit 451f1c5
Showing 2 changed files with 6 additions and 44 deletions.
python/tvm/relay/frontend/common.py (35 changes: 0 additions & 35 deletions)
@@ -661,7 +661,6 @@ def unbind(data, axis=0):
 def gru_cell(
     input_seqs,
     hidden_state,
-    hidden_size,
     w_inp,
     w_hid,
     b_inp=None,
@@ -683,8 +682,6 @@ def gru_cell(
         Shape = (batch, feature_size)
     hidden_state : relay.Expr
         Hidden state. shape = (batch_size, hidden_size)
-    hidden_size : int
-        The number of features in the hidden state. It is needed for correct and quick split of weights.
     w_inp, w_hid : relay.Expr
         weight matrices. wi shape = (3 * hidden_size, feature_size)
         wh shape = (3 * hidden_size, hidden_size)
@@ -709,38 +706,6 @@ def gru_cell(

     outputs_list = []
     for x_t in input_seqs if not backwards else reversed(input_seqs):
-        # x_t shape = (batch, feature size), step shape = (batch, feature size + hidden_size)
-        # step = _op.concatenate([x_t, hidden_state], axis=1)
-        # w_irz, w_in = _op.split(w_inp, [2*hidden_size], axis=0)
-        # w_hrz, w_hn = _op.split(w_hid, [2*hidden_size], axis=0)
-        # cat_w = _op.concatenate([w_irz, w_hrz], axis=1)
-        # # Instead of nn.dense(x_t, w_inp) + nn.dense(hidden_state, w_hid)
-        # # nn.dense(step, cat_w) is used
-        # # gates shape = (batch, 2 * hidden_size)
-        # rz_gates = _op.nn.dense(step, cat_w)
-        # # Add biases
-        # if b_inp is not None:
-        # b_irz, b_in = _op.split(b_inp, [2*hidden_size], axis=0)
-        # rz_gates += b_irz
-        # if b_hid is not None:
-        # b_hrz, b_hn = _op.split(b_hid, [2*hidden_size], axis=0)
-        # rz_gates += b_hrz
-        # # TODO(vvchernov): check similarity of r_act and z_act and change sequence act->split
-        # # any gate shape = (batch, hidden_size)
-        # r_gate, z_gate = _op.split(rz_gates, 2, axis=-1)
-
-        # r_gate = r_act(r_gate)
-        # z_gate = z_act(z_gate)
-
-        # ni_gate = _op.nn.dense(x_t, w_in)
-        # if b_inp is not None:
-        # ni_gate += b_in
-        # nh_gate = _op.nn.dense(hidden_state, w_hn)
-        # if b_hid is not None:
-        # nh_gate += b_hn
-
-        # n_gate = n_act(ni_gate + r_gate * nh_gate)
-
         xwt = _op.nn.dense(x_t, w_inp)
         i_r, i_z, i_n = _op.split(xwt, 3, axis=1)
         w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
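
For reference, the loop body that remains computes the standard GRU gates by splitting the stacked weight matrices into three equal chunks, which is why the explicit hidden_size argument is no longer needed. A minimal NumPy sketch of one equivalent step, assuming the PyTorch [r, z, n] row ordering; gru_cell_step and its helper names are illustrative only, not TVM APIs:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_cell_step(x_t, h, w_inp, w_hid, b_inp=None, b_hid=None):
    # x_t: (batch, feature_size), h: (batch, hidden_size)
    # w_inp: (3 * hidden_size, feature_size), w_hid: (3 * hidden_size, hidden_size)
    # np.split infers the chunk size from the weights, so no hidden_size argument is needed
    xwt = x_t @ w_inp.T                           # (batch, 3 * hidden_size), analogue of _op.nn.dense
    i_r, i_z, i_n = np.split(xwt, 3, axis=1)
    w_hr, w_hz, w_hn = np.split(w_hid, 3, axis=0)
    h_r, h_z, h_n = h @ w_hr.T, h @ w_hz.T, h @ w_hn.T
    if b_inp is not None:
        b_ir, b_iz, b_in = np.split(b_inp, 3)
        i_r, i_z, i_n = i_r + b_ir, i_z + b_iz, i_n + b_in
    if b_hid is not None:
        b_hr, b_hz, b_hn = np.split(b_hid, 3)
        h_r, h_z, h_n = h_r + b_hr, h_z + b_hz, h_n + b_hn
    r = sigmoid(i_r + h_r)                        # reset gate
    z = sigmoid(i_z + h_z)                        # update gate
    n = np.tanh(i_n + r * h_n)                    # new gate
    return (1.0 - z) * n + z * h                  # next hidden state
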
python/tvm/relay/frontend/pytorch.py (15 changes: 6 additions & 9 deletions)
@@ -2318,7 +2318,6 @@ def flip(self, inputs, input_types):
     def bidir_gru_cell(
         self,
         input_seqs,
-        hidden_size,
         weights_dicts,
     ):
         """
@@ -2327,13 +2326,11 @@ def bidir_gru_cell(
         seq_len = len(input_seqs)
         forward_outputs, fw_H_t = gru_cell(
             input_seqs,
-            hidden_size=hidden_size,
             **weights_dicts[0],
         )

         reverse_outputs, rev_H_t = gru_cell(
             input_seqs,
-            hidden_size=hidden_size,
             **weights_dicts[1],
             backwards=True,
         )
@@ -2346,7 +2343,7 @@

         return final_outputs, _op.stack([fw_H_t, rev_H_t], axis=0)

-    def gru_layers(self, input_data, layer_weights_dicts, bidirectional, hidden_size, dropout_p=0.0):
+    def gru_layers(self, input_data, layer_weights_dicts, bidirectional, dropout_p=0.0):
         """
         Methods iterates layers for Stacked LSTM
         """
@@ -2359,9 +2356,9 @@ def gru_layers(self, input_data, layer_weights_dicts, bidirectional, hidden_size
             # input_seqs shape = [seq_num, (batch, feature_size)] or
             # [seq_num, (batch, 2*feature_size)] for bidirectional
             if bidirectional:
-                input_seqs, H_t = self.bidir_gru_cell(input_seqs, hidden_size, weights_dicts)
+                input_seqs, H_t = self.bidir_gru_cell(input_seqs, weights_dicts)
             else:
-                input_seqs, H_t = gru_cell(input_seqs, **weights_dicts[0], hidden_size=hidden_size)
+                input_seqs, H_t = gru_cell(input_seqs, **weights_dicts[0])

             output_hiddens.append(H_t)

@@ -2377,7 +2374,8 @@ def gru_layers(self, input_data, layer_weights_dicts, bidirectional, hidden_size

     def gru(self, inputs, input_types):
         """
-        Description of GRU in pytorch:https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=gru#torch.nn.GRU
+        Description of GRU in pytorch:
+        https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=gru#torch.nn.GRU
         """
         # TODO (vvchernov): support dropout
         assert len(inputs) == 9, "Input of size 9 is expected"
@@ -2430,7 +2428,7 @@ def gru(self, inputs, input_types):
         X_dtype = input_types[0]
         X_shape = _infer_shape(X)  # (seq_num, batch, feature_size)

-        hidden_size = _infer_shape(_weights[0])[0] / 3
+        hidden_size = int(_infer_shape(_weights[0])[0] / 3)
         batch_size = X_shape[1]

         # Initialize hidden states if not provided.
@@ -2493,7 +2491,6 @@ def gru(self, inputs, input_types):
             X,
             layer_weights_dicts,
             bidirectional,
-            hidden_size=hidden_size,
             dropout_p=dropout_p,
         )

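The int() cast matters because _infer_shape returns integer dimensions, while dividing by 3 with / yields a Python float, which is awkward wherever hidden_size is later used as a shape value (for example when building the default initial hidden state). The value itself comes from the stacked input-hidden weight, whose first dimension is 3 * hidden_size. A small illustrative check against PyTorch (layer sizes arbitrary, assuming torch is available):

import torch

gru = torch.nn.GRU(input_size=10, hidden_size=20)
w_ih = gru.weight_ih_l0              # stacked [r, z, n] weights, shape (3 * hidden_size, input_size) = (60, 10)
hidden_size = w_ih.shape[0] // 3     # 20, and an int; plain `/ 3` would give 20.0
assert hidden_size == gru.hidden_size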
