4747 infer_value ,
4848 new_var ,
4949 unbind ,
50+ gru_cell ,
5051 lstm_cell ,
5152)
5253
@@ -2349,56 +2350,41 @@ class GRU(RNN):
23492350 """Operator convert for GRU"""
23502351
23512352 @classmethod
2352- def generate_gru (
2353- cls , X_steps , H_t , W , R , B , linear_before_reset , f_act , g_act , W_dtype , backwards = False
2353+ def bidir_gru_cell (
2354+ cls ,
2355+ input_seqs ,
2356+ weight_dicts ,
2357+ acts ,
23542358 ):
2355- """Create an unrolled gru loop.
2356-
2357- See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
23582359 """
2359- h_list = []
2360- seq_length = len (X_steps )
2361- for i in range (seq_length ):
2362- step = X_steps [i ] if not backwards else X_steps [seq_length - (i + 1 )]
2363- step = _op .squeeze (step , axis = [0 ])
2364- current = _op .nn .dense (step , W )
2365- cz , cr , ch = _op .split (current , 3 , axis = 1 )
2366- rz , rr , rh = _op .split (R , 3 , axis = 0 )
2367- z = cz + _op .nn .dense (H_t , rz )
2368- r = cr + _op .nn .dense (H_t , rr )
2369- if B is not None :
2370- WB , RB = _op .split (B , 2 )
2371- wbz , wbr , wbh = _op .split (WB , 3 , axis = - 1 )
2372- rbz , rbr , rbh = _op .split (RB , 3 , axis = - 1 )
2373- z += wbz + rbz
2374- r += wbr + rbr
2375- if linear_before_reset :
2376- h = ch + (r * (_op .nn .dense (H_t , rh ) + rbh )) + wbh
2377- else :
2378- h = ch + _op .nn .dense ((r * H_t ), rh ) + wbh + rbh
2379- else :
2380- if linear_before_reset :
2381- h = ch + (r * (_op .nn .dense (H_t , rh )))
2382- else :
2383- h = ch + _op .nn .dense ((r * H_t ), rh )
2384-
2385- z = f_act (z )
2386- r = f_act (r )
2387- h = g_act (h )
2388-
2389- H_t = ((_expr .const (1 , dtype = W_dtype ) - z ) * h ) + (z * H_t )
2390- h_list .append (_op .expand_dims (H_t , axis = 0 ))
2360+ Bidirectional GRU cell: runs gru_cell over input_seqs forwards and backwards, then stacks both directions
2361+ """
2362+ seq_len = len (input_seqs )
2363+ forward_outputs , fw_H_t = gru_cell (
2364+ input_seqs ,
2365+ ** weight_dicts [0 ],
2366+ rz_act = acts [0 ],
2367+ n_act = acts [1 ],
2368+ )
23912369
2392- if backwards :
2393- # Canonical view is hidden states from the first token not last
2394- h_list = h_list [::- 1 ]
2370+ reverse_outputs , rev_H_t = gru_cell (
2371+ input_seqs ,
2372+ ** weight_dicts [1 ],
2373+ rz_act = acts [2 ],
2374+ n_act = acts [3 ],
2375+ backwards = True ,
2376+ )
23952377
2396- # Concatenate outputs and add back in direction axis.
2397- concatenated = _op .concatenate (h_list , 0 )
2398- output = _op .expand_dims (concatenated , axis = 1 )
2399- H_t = _op .expand_dims (H_t , axis = 0 )
2378+ final_outputs = []
2379+ for i in range (seq_len ):
2380+ final_outputs .append (
2381+ _op .stack ([forward_outputs [i ], reverse_outputs [seq_len - 1 - i ]], axis = 0 )
2382+ )
24002383
2401- return output , H_t
2384+ return (
2385+ _op .stack (final_outputs , axis = 0 ),
2386+ _op .stack ([fw_H_t , rev_H_t ], axis = 0 ),
2387+ )
24022388
24032389 @classmethod
24042390 def _impl_v7 (cls , inputs , attr , params ):
@@ -2416,20 +2402,14 @@ def _impl_v7(cls, inputs, attr, params):
24162402 W_dtype = infer_type (Wp ).checked_type .dtype
24172403
24182404 if num_directions not in [1 , 2 ]:
2419- raise NotImplementedError (
2420- f"Directions for GRUs should be either 1 or 2 got { num_directions } "
2421- )
2405+ raise ValueError ("num_directions must be either 1 or 2!" )
24222406
24232407 X_shape = infer_shape (X )
24242408 hidden_size = infer_shape (Rp )[- 1 ]
24252409 batch_size = X_shape [1 ]
24262410
2427- # Initialize state if not provided.
2428- # Otherwise remove bidirectional axis.
24292411 if Hp_0 is None :
24302412 Hp_0 = _op .zeros ((num_directions , batch_size , hidden_size ), W_dtype )
2431- if Bp is None :
2432- Bp = _op .zeros ((num_directions , hidden_size * 6 ), W_dtype )
24332413
24342414 if "activations" in attr :
24352415 activations = attr ["activations" ]
@@ -2460,39 +2440,54 @@ def _impl_v7(cls, inputs, attr, params):
24602440 else :
24612441 acts = [_op .sigmoid , _op .tanh ] * 2
24622442
2463- result_output = []
2464- result_H = []
2443+ # TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
2444+ X_steps = unbind ( X , axis = 0 )
24652445
2466- X_steps = _op .split (X , indices_or_sections = X_shape [0 ], axis = 0 )
24672446 H_ts = _op .split (Hp_0 , num_directions )
24682447 Ws = _op .split (Wp , num_directions )
24692448 Rs = _op .split (Rp , num_directions )
2470- Bs = _op .split (Bp , num_directions )
24712449
2450+ if Bp is not None :
2451+ Bs = _op .split (Bp , num_directions )
2452+
2453+ weights_dicts = []
24722454 for i in range (num_directions ):
2473- H_t = _op . squeeze ( H_ts [ i ], axis = [ 0 ])
2474- W = _op . squeeze ( Ws [ i ], axis = [ 0 ])
2475- R = _op .squeeze (Rs [i ], axis = [0 ])
2476- B = _op . squeeze ( Bs [ i ], axis = [ 0 ])
2477- f_act , g_act = acts [ i * 2 : ( i + 1 ) * 2 ]
2478- output , H = GRU . generate_gru (
2479- X_steps = X_steps ,
2480- H_t = H_t ,
2481- W = W ,
2482- R = R ,
2483- B = B ,
2484- linear_before_reset = linear_before_reset ,
2485- f_act = f_act ,
2486- g_act = g_act ,
2487- W_dtype = W_dtype ,
2488- backwards = i == 1 ,
2489- )
2455+ weights_dict = {}
2456+
2457+ weights_dict [ "hidden_state" ] = _op .squeeze (H_ts [i ], axis = [0 ])
2458+ weights_dict [ "linear_before_reset" ] = linear_before_reset
2459+
2460+ # Weights permutation: onnx format z-r-n, gru cell format r-z-n
2461+ matz , matr , matn = _op . split ( _op . squeeze ( Ws [ i ], axis = [ 0 ]), 3 )
2462+ weights_dict [ "w_inp" ] = _op . concatenate ([ matr , matz , matn ], axis = 0 )
2463+ matz , matr , matn = _op . split ( _op . squeeze ( Rs [ i ], axis = [ 0 ]), 3 )
2464+ weights_dict [ "w_hid" ] = _op . concatenate ([ matr , matz , matn ], axis = 0 )
2465+ if Bp is not None :
2466+ Bi , Bh = _op . split ( Bs [ i ], 2 , - 1 )
2467+ matz , matr , matn = _op . split ( _op . squeeze ( Bi , axis = [ 0 ]), 3 )
2468+ weights_dict [ "b_inp" ] = _op . concatenate ([ matr , matz , matn ], axis = 0 )
2469+ matz , matr , matn = _op . split ( _op . squeeze ( Bh , axis = [ 0 ]), 3 )
2470+ weights_dict [ "b_hid" ] = _op . concatenate ([ matr , matz , matn ], axis = 0 )
2471+ weights_dicts . append ( weights_dict )
24902472
2491- result_output .append (output )
2492- result_H .append (H )
2473+ if num_directions == 2 :
2474+ output , H = GRU .bidir_gru_cell (
2475+ input_seqs = X_steps ,
2476+ weight_dicts = weights_dicts ,
2477+ acts = acts ,
2478+ )
2479+ else :
2480+ # outputs shape = [seqs_num, (batch_size, hidden_size)]
2481+ outputs , H = gru_cell (
2482+ input_seqs = X_steps ,
2483+ ** weights_dicts [0 ],
2484+ rz_act = acts [0 ],
2485+ n_act = acts [1 ],
2486+ )
24932487
2494- output = _op .concatenate (result_output , axis = 1 )
2495- H = _op .concatenate (result_H , axis = 0 )
2488+ # output shape = (seqs_num, num_directions, batch_size, hidden_size)
2489+ output = _op .expand_dims (_op .stack (outputs , axis = 0 ), axis = 1 )
2490+ H = _op .expand_dims (H , axis = 0 )
24962491
24972492 return _expr .TupleWrapper (_expr .Tuple ((output , H )), 2 )
24982493
0 commit comments