Skip to content

Commit c36f884

Browse files
authored
SyntaxError: invalid syntax in the def fused_lstm_gates - #6 in python 2.7
See: https://stackoverflow.com/questions/15301999/python-2-x-default-arguments-with-args-and-kwargs Note, it looks like, just replacing the order of *args and name wouldn't work: >> def lstm_gates2_op(c, a1, a2, name): >> print c, a1, a2, name >> >> def fused_lstm_gates(c, name=None, *args): >> print len(args) >> assert len(args) == 2 >> return lstm_gates2_op(c, *args, name=name) >> fused_lstm_gates(1, 2, 3, 4) 2 1 3 4 2 >> fused_lstm_gates(1, 2, 3, name=4) TypeError: fused_lstm_gates() got multiple values for keyword argument 'name' While the following is: >> def lstm_gates2_op(c, a1, a2, name): >> print c, a1, a2, name >> >> def fused_lstm_gates(c, *args, **kwargs): >> name = kwargs.pop('name', None) >> assert len(args) == 2 >> return lstm_gates2_op(c, *args, name=name) >> >> fused_lstm_gates(1, 2, 3) 1 2 3 None >> fused_lstm_gates(1, 2, 3, "test") 1 2 3 test
1 parent 8095147 commit c36f884

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

blocksparse/ewops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,11 @@ def ew_z_xb_grad(op, dz):
159159
lstm_gates4_op = _op_module.lstm_gates4
160160
lstm_gates4_grad_op = _op_module.lstm_gates4_grad
161161

162-
def fused_lstm_gates(c, *args, name=None):
162+
def fused_lstm_gates(c, *args, **kwargs):
163163
# returns c_next, h_next
164-
164+
165+
assert len(kwargs) <= 1
166+
name = kwargs.pop('name', None)
165167
# args is h (all four gates fused in single tensor)
166168
if len(args) == 1:
167169
return lstm_gates_op(c, args[0], name=name)

0 commit comments

Comments
 (0)