Better way to define main network dict (make_root_net_dict) #44

Closed
@albertz

Currently, most of the examples (test cases) contain code like this:

class Model(rc.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = rc.nn.Linear(n_out=13)

  def forward(self) -> rc.nn.LayerRef:
    x = rc.nn.get_extern_data("data")
    x = self.linear(x)
    return x

model = Model()
net_dict = model.make_root_net_dict()

I don't like this much: Model takes no inputs and instead calls rc.nn.get_extern_data itself. That makes the Model code hard to reuse in other contexts.

A current alternative is this variant:

linear = Linear(n_out=13)

with NameCtx.new_root() as name_ctx:
  out = linear(get_extern_data("data"))
  name_ctx.make_default_output(out)

  net_dict = name_ctx.make_net_dict()

This is equivalent.

But now you have not defined a Model at all, so you also cannot easily reuse this code as a module or building block in some other context. So this is not a good solution either.

What I would like is this definition of the model:

class Model(rc.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = rc.nn.Linear(n_out=13)

  def forward(self, x: rc.nn.LayerRef) -> rc.nn.LayerRef:
    return self.linear(x)

model = Model()

But then, where do you get the extern data in, or rather, how do you connect the extern data to the inputs of model? And how do you get the net dict in the end?

Of course, you could do this now:

with NameCtx.new_root() as name_ctx:
  out = model(get_extern_data("data"))
  name_ctx.make_default_output(out)

  net_dict = name_ctx.make_net_dict()

However, that might not be the behavior you want, because you now get one big subnetwork named Model. You may instead want linear to be a layer directly in the net dict, not inside a subnetwork.

You can also already do this:

with NameCtx.new_root() as name_ctx:
  out = model(get_extern_data("data"), name=name_ctx)
  name_ctx.make_default_output(out)

  net_dict = name_ctx.make_net_dict()

With name=name_ctx, you explicitly tell it to use the root name scope as the name, which has the effect that it will not become a subnetwork.
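
To make the difference concrete, here is a hand-written sketch of the two net dict shapes. These dicts are illustrative assumptions based on the usual RETURNN net dict layout, not actual output of make_net_dict, and layer_paths is a hypothetical helper used only to list the layer names.

```python
# Hand-written illustration of the two shapes, not generated output.

# Without name=name_ctx: everything ends up inside one subnetwork named "Model".
nested_net_dict = {
    "Model": {
        "class": "subnetwork",
        "from": "data:data",
        "subnetwork": {
            "linear": {"class": "linear", "from": "data", "n_out": 13},
            "output": {"class": "copy", "from": "linear"},
        },
    },
    "output": {"class": "copy", "from": "Model"},
}

# With name=name_ctx: "linear" is a layer directly in the root net dict.
flat_net_dict = {
    "linear": {"class": "linear", "from": "data:data", "n_out": 13},
    "output": {"class": "copy", "from": "linear"},
}


def layer_paths(net_dict, prefix=""):
    """List all layer names, using '/' for layers nested in a subnetwork."""
    paths = []
    for name, layer in net_dict.items():
        paths.append(prefix + name)
        if layer.get("class") == "subnetwork":
            paths.extend(layer_paths(layer["subnetwork"], prefix + name + "/"))
    return paths


print(layer_paths(nested_net_dict))
print(layer_paths(flat_net_dict))
```

In the nested variant the linear layer is only reachable as Model/linear, whereas in the flat variant it sits at the top level.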

But maybe we should also introduce a better make_root_net_dict, for example:

net_dict = make_root_net_dict(model, x="data")
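
A minimal sketch of how such a helper could work, assuming each keyword argument maps a forward argument name to an extern data key. It simply wraps the name=name_ctx pattern from above. To keep it runnable on its own, ToyNameCtx, toy_get_extern_data and ToyLinearModel are hypothetical stand-ins for NameCtx, get_extern_data and Model, not the real classes.

```python
from typing import Any, Dict


class ToyNameCtx:
    """Hypothetical stand-in for NameCtx: collects root-level layer dicts."""

    def __init__(self) -> None:
        self.layers: Dict[str, Dict[str, Any]] = {}

    def __enter__(self) -> "ToyNameCtx":
        return self

    def __exit__(self, *exc_info) -> bool:
        return False  # do not swallow exceptions

    def make_default_output(self, ref: str) -> None:
        self.layers["output"] = {"class": "copy", "from": ref}

    def make_net_dict(self) -> Dict[str, Dict[str, Any]]:
        return dict(self.layers)


def toy_get_extern_data(key: str) -> str:
    """Hypothetical stand-in for get_extern_data: returns a layer reference."""
    return f"data:{key}"


class ToyLinearModel:
    """Hypothetical stand-in for Model: takes its input as an argument."""

    def __init__(self, n_out: int) -> None:
        self.n_out = n_out

    def __call__(self, x: str, *, name: ToyNameCtx) -> str:
        # Register the layer directly in the given (root) name scope.
        name.layers["linear"] = {"class": "linear", "from": x, "n_out": self.n_out}
        return "linear"


def make_root_net_dict(model, **extern_data_keys: str) -> Dict[str, Dict[str, Any]]:
    """Connect extern data to the model inputs and build the root net dict."""
    with ToyNameCtx() as name_ctx:
        inputs = {arg: toy_get_extern_data(key) for arg, key in extern_data_keys.items()}
        out = model(**inputs, name=name_ctx)  # root scope, so no subnetwork
        name_ctx.make_default_output(out)
    return name_ctx.make_net_dict()


net_dict = make_root_net_dict(ToyLinearModel(n_out=13), x="data")
print(net_dict)
```

With this shape, the model stays free of any extern data lookup, and the mapping from inputs to data keys lives in the single make_root_net_dict call.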

Other suggestions?
