Description
Currently, most of the examples (test cases) contain code like this:

```python
class Model(rc.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = rc.nn.Linear(n_out=13)

    def forward(self) -> rc.nn.LayerRef:
        x = rc.nn.get_extern_data("data")
        x = self.linear(x)
        return x

model = Model()
net_dict = model.make_root_net_dict()
```
I don't like this much, because `Model` does not get any inputs but uses `rc.nn.get_extern_data` instead. This makes the code of `Model` hard to reuse in other contexts.
Another variant that is currently possible:

```python
linear = Linear(n_out=13)
with NameCtx.new_root() as name_ctx:
    out = linear(get_extern_data("data"))
    name_ctx.make_default_output(out)
    net_dict = name_ctx.make_net_dict()
```
This is equivalent. But now you have not created any `Model` at all, so you cannot easily reuse this code as a module or building block in another context either. So this is not a good solution.
What I would like instead is this definition of the model:

```python
class Model(rc.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = rc.nn.Linear(n_out=13)

    def forward(self, x: rc.nn.LayerRef) -> rc.nn.LayerRef:
        return self.linear(x)

model = Model()
```
But then how do you feed the extern data in? More precisely, how do you connect the extern data to the inputs of `model`, and how do you get the net dict in the end?
Of course, you could do this now:
```python
with NameCtx.new_root() as name_ctx:
    out = model(get_extern_data("data"))
    name_ctx.make_default_output(out)
    net_dict = name_ctx.make_net_dict()
```
However, that might not be the behavior you want, because you would now get one big subnetwork named `Model`. You may rather want `linear` to be a layer directly in the net dict, not inside some subnetwork.
You can also already do this:
```python
with NameCtx.new_root() as name_ctx:
    out = model(get_extern_data("data"), name=name_ctx)
    name_ctx.make_default_output(out)
    net_dict = name_ctx.make_net_dict()
```
With `name=name_ctx`, you explicitly tell it to use the root name scope as its name. As a result, it does not become a subnetwork.
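To make the difference concrete, here is a toy sketch in plain Python with no returnn_common involved. `model_layers`, `call_model` and the `into_root` flag are hypothetical stand-ins for calling `model(...)` without and with `name=name_ctx`. The dicts are only simplified RETURNN-style net dicts, so the layer options and reference conventions in real returnn_common output may differ.

```python
from typing import Any, Dict

Layers = Dict[str, Dict[str, Any]]


def model_layers(src: str) -> Layers:
    """Hypothetical: the layers the toy Model body creates (one linear layer reading `src`)."""
    return {"linear": {"class": "linear", "from": src, "n_out": 13}}


def call_model(net: Layers, src: str, into_root: bool = False) -> str:
    """Hypothetical: add the toy Model to `net` and return the name of its output layer.

    into_root=False mimics model(x): the layers are wrapped in one subnetwork layer.
    into_root=True mimics model(x, name=name_ctx): the layers land directly in `net`.
    """
    if into_root:
        net.update(model_layers(src))
        return "linear"
    # Toy convention: inside the subnetwork, "data" refers to the subnetwork input.
    sub = model_layers("data")
    sub["output"] = {"class": "copy", "from": "linear"}
    net["model"] = {"class": "subnetwork", "from": src, "subnetwork": sub}
    return "model"


nested: Layers = {}
nested["output"] = {"class": "copy", "from": call_model(nested, "data:data")}

flat: Layers = {}
flat["output"] = {"class": "copy", "from": call_model(flat, "data:data", into_root=True)}

print(sorted(nested))  # ['model', 'output']
print(sorted(flat))    # ['linear', 'output']
```

In the nested variant, `linear` is only reachable through the `model` subnetwork. In the flat variant, it is a top-level layer, which is what `name=name_ctx` achieves.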
But maybe we should also introduce a better `make_root_net_dict`, for example:

```python
net_dict = make_root_net_dict(model, x="data")
```
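One possible reading of this proposal: each keyword argument maps a `forward()` argument name (`x`) to an extern data key (`"data"`). The helper then calls `forward()` directly in the root scope, so no subnetwork is created. The sketch below is a toy, self-contained version of that idea. `NetCtx`, `LayerRef`, `Linear` and this `make_root_net_dict` are hypothetical stand-ins, not the returnn_common API, and the net dict format is only simplified RETURNN style.

```python
import inspect
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class NetCtx:
    """Hypothetical toy root name scope: collects a flat, RETURNN-style net dict."""
    layers: Dict[str, Dict[str, Any]] = field(default_factory=dict)


@dataclass
class LayerRef:
    """Hypothetical toy reference to a layer (by name) inside a NetCtx."""
    ctx: NetCtx
    name: str


class Linear:
    def __init__(self, n_out: int, name: str = "linear"):
        self.n_out = n_out
        self.name = name

    def __call__(self, x: LayerRef) -> LayerRef:
        # Register the layer in the scope the input belongs to (here: the root).
        x.ctx.layers[self.name] = {"class": "linear", "from": x.name, "n_out": self.n_out}
        return LayerRef(x.ctx, self.name)


class Model:
    def __init__(self):
        self.linear = Linear(n_out=13)

    def forward(self, x: LayerRef) -> LayerRef:
        # No get_extern_data here: the input comes in as an argument.
        return self.linear(x)


def make_root_net_dict(model: Any, **extern_data: str) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: bind forward() args to extern data keys and build a flat net dict."""
    ctx = NetCtx()
    arg_names = set(inspect.signature(model.forward).parameters)
    unknown = set(extern_data) - arg_names
    if unknown:
        raise TypeError(f"forward() has no argument(s) {sorted(unknown)}")
    inputs = {arg: LayerRef(ctx, "data:" + key) for arg, key in extern_data.items()}
    out = model.forward(**inputs)
    ctx.layers["output"] = {"class": "copy", "from": out.name}
    return ctx.layers


net_dict = make_root_net_dict(Model(), x="data")
print(net_dict)
# {'linear': {'class': 'linear', 'from': 'data:data', 'n_out': 13}, 'output': {'class': 'copy', 'from': 'linear'}}
```

Binding by argument name keeps `Model` free of `get_extern_data`, so the same module can also be called on arbitrary layers elsewhere. An unknown key is rejected up front. A missing one makes the `forward()` call itself raise a `TypeError`.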
Other suggestions?