
Rec design for recurrent definitions / loops #16

@albertz

This issue is to collect some thoughts on the design of recurrent loops, which wraps RETURNN's RecLayer with an explicit subnetwork.

The main goal is to make this very straightforward and simple for the user. We can abstract away from the underlying RecLayer if that makes things easier, and we can also extend RETURNN itself if needed.

Also related is #6 (rec prev mechanism); this issue might fix/resolve #6, although not necessarily.

This also needs some mechanism for unrolling/unstacking, i.e. when we iterate over an input x with some time axis, to get x[t]. This is rwth-i6/returnn#552.


We want to define a loop like in this pseudo Python code:

x  # given, shape {batch, time, dim}
h = Zeros({batch,dim})()  # initial state, shape {batch,dim}
out = []
for t in range(x.max_seq_len):
  x_lin = Linear(dim)(x[t])
  h_prev = h
  h = Linear(dim)(x_lin + h_prev)
  out.append(h)

h  # final state
out  # shape {time, batch, dim}

Current design:

There is Loop(), which can be used as a with context; it corresponds to the for-loop in the example, or in general to a while-loop. Like:

with Loop() as loop:
  ...

There is State(), which can define hidden state (for any module or any code).

The example above can be written as:

h = State({batch, dim}, initial=0)
with Loop() as loop:  # this introduces a new loop
  x_t = loop.unstack(x)  # shape {batch, dim}

  x_lin = Linear(dim)(x_t)
  h_prev = h.get()
  h_ = Linear(dim)(x_lin + h_prev)  # shape {batch, dim}
  h.assign(h_)

  out = loop.stack(h_)  # shape {time,batch,dim}
  h_last = loop.last(h_)

# h.get() would now return the last state
# h_last is an alternative
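
For orientation, this is roughly the kind of RETURNN net dict such a Loop could translate to. This is only a sketch under assumptions: the layer names, the "from": "data" source, and the concrete dim value are made up here; the actual generated dict may differ.

```python
dim = 512  # assumed feature dimension

# Hypothetical RETURNN net dict corresponding to the Loop example above.
network = {
    "loop": {
        "class": "rec", "from": "data",  # iterates over the time axis of x
        "unit": {
            # current frame of the rec layer input, i.e. loop.unstack(x)
            "x_lin": {"class": "linear", "from": "data:source", "n_out": dim},
            # x_lin + h_prev; "prev:h" is the state from the previous iteration
            "h_in": {"class": "combine", "kind": "add", "from": ["x_lin", "prev:h"]},
            "h": {"class": "linear", "from": "h_in", "n_out": dim},
            # the "output" sub-layer, stacked over time, corresponds to `out`
            "output": {"class": "copy", "from": "h"},
        },
    },
}
```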

Or with a module as:

class MyRec(Module):
  def __init__(self):
    super().__init__()
    self.x_linear = Linear(dim)
    self.h_linear = Linear(dim)
    self.h = State({batch, dim}, initial=0)

  def forward(self, x):
    # x shape is {batch, dim}
    x_lin = self.x_linear(x)
    h_prev = self.h.get()
    h = self.h_linear(x_lin + h_prev)  # shape {batch, dim}
    self.h.assign(h)
    return h

rec = MyRec()
with Loop() as loop:  # this introduces a new loop
  x_t = loop.unstack(x)  # shape {batch, dim}
  h_ = rec(x_t)  # shape {batch,dim}. this represents the inner value
  h = loop.last(h_)  # shape {batch,dim}
  out = loop.stack(h_)  # shape {time,batch,dim}

For the TF name scopes (and variable scopes), we should follow #25, i.e. make them match the module hierarchy exactly.

The RETURNN layer name of the RecLayer created via Loop does not matter too much. It could be arbitrary, or use some clever (but simple) logic such as the name of the first module. The RETURNN layer hierarchy can be independent of the actual TF name scopes (via #25).

Special options for the RecLayer such as include_eos can be passed as options to Loop, like Loop(include_eos=True), or set via a method, like loop.set_include_eos(True).

Loop (potential) methods: unstack, stack, last, as used in the examples above. (... See discussion below for more ...)

State has the methods get and assign.
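
To make the intended semantics of get/assign and stack/last concrete, here is a plain NumPy reference of the example loop from above. This is purely illustrative (weights, shapes and the time-major layout of x are assumptions for readability), not how the API is implemented.

```python
import numpy as np

batch, time, dim = 3, 7, 5
x = np.random.randn(time, batch, dim).astype("float32")  # time-major here for simple x[t]
w_x = np.random.randn(dim, dim).astype("float32")        # stands in for Linear(dim)
w_h = np.random.randn(dim, dim).astype("float32")

h = np.zeros((batch, dim), dtype="float32")  # State({batch, dim}, initial=0)
out = []
for t in range(time):
    x_lin = x[t] @ w_x          # Linear(dim)(loop.unstack(x))
    h = (x_lin + h) @ w_h       # h.assign(Linear(dim)(x_lin + h.get()))
    out.append(h)               # collected by loop.stack(...)
out = np.stack(out)             # shape {time, batch, dim}
h_last = h                      # loop.last(...), i.e. the final state
```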

Current reasonings:

Why no special base class Rec deriving from Module? We want to easily allow using any kind of module inside a loop, and we think the current API makes this more straightforward.

Why is h not an argument of forward, and why use State instead? This allows calling other sub modules, which might define their own hidden state, so the root recurrent module does not need to know about all the hidden states of its sub modules.

Why have the hidden state explicit, and not use something closer to self.prev? To make the behavior more straightforward.

The current design allows for nested loops and sub modules with hidden state.
Only the Loop() call actually introduces a new loop.

class MySubRec(Module):
  def __init__(self):
    super().__init__()
    self.h = State({batch,dim})

  def forward(self, a):
    # assume a shape {batch,dim}
    h = self.h.get() + a
    self.h.assign(h)
    return h

class MyRec(Module):
  def __init__(self):
    super().__init__()
    self.sub = MySubRec()
    self.h = State({batch,dim})

  def forward(self, x):
    a = self.h.get() + x

    # example with sub as nested loop
    with Loop() as loop:
      y = self.sub(a)
      y = loop.last(y)

    # or: example with sub in same loop
    y = self.sub(a)
    
    self.h.assign(y)
    return y
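
For completeness, driving MyRec over a sequence would then look like the earlier module example, i.e. a sketch along these lines (x as before):

```python
rec = MyRec()
with Loop() as loop:          # outer loop over the time axis of x
  x_t = loop.unstack(x)       # shape {batch, dim}
  y_ = rec(x_t)               # shape {batch, dim}, per frame
  y = loop.stack(y_)          # shape {time, batch, dim}
  y_last = loop.last(y_)      # shape {batch, dim}
```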

There should not be any special handling needed for the Choice layer.
Note that the search flag and train flag logic is a separate thing (#18).

There should not be any special handling needed depending on whether the input to a rec module call is inside the current/same loop or not. unstack on some value which is already inside the loop would not make sense, though, and should result in an error. But this should all be covered by existing RETURNN logic already.

RETURNN's automatic rec optimization should not cause any problems. RETURNN should already guarantee that the optimized and unoptimized computations are equivalent. From the user's viewpoint, it should never matter whether the loop is optimized; anything else would be rwth-i6/returnn#573. On the returnn-common level, it should not matter.


Example of an LSTM for a single step:

class Lstm(Module):
  def __init__(self):
    super().__init__()
    self.h = State({batch,dim})
    self.c = State({batch,dim})
    self.ff_linear = Linear(dim * 4)
    self.rec_linear = Linear(dim * 4)

  def forward(self, x):
    # x shape is {batch,dim} (single frame)
    x_ = self.ff_linear(x)
    h_ = self.rec_linear(self.h.get())
    x_in, g_in, g_forget, g_out = split(x_ + h_, 4)
    c = self.c.get() * sigmoid(g_forget) + tanh(x_in) * sigmoid(g_in)
    self.c.assign(c)
    h = tanh(c) * sigmoid(g_out)
    self.h.assign(h)
    return h
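
Applying this single-step Lstm over a whole sequence again just goes through Loop, following the same pattern as the examples above (sketch):

```python
lstm = Lstm()
with Loop() as loop:
  x_t = loop.unstack(x)       # shape {batch, dim}
  h_ = lstm(x_t)              # shape {batch, dim}, per frame
  out = loop.stack(h_)        # shape {time, batch, dim}
  h_last = loop.last(h_)      # final hidden state, shape {batch, dim}
```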
