Description
Motivation
As a modular library, we'd like our components to be as orthogonal to each other as possible.
As such, we should aim at minimising the inter-dependency of the components (at least on what is exposed to the users).
One such restriction is that our data collectors expect policies to be TensorDictModule instances.
It should be possible to automatically build a TensorDictModule just by reading the signature of a policy.
Example:
policy = MyRegularModule()
collector = datacollector(make_env, policy)
assert isinstance(collector.policy, TensorDictModule)
assert collector.policy.out_keys == ["action"]
assert collector.policy.module is policy
We would assume that the first output of the policy is an action (and has the "action" key).
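A minimal sketch of how the keys could be read from a policy's signature, using only the standard library. The helper name `infer_keys` is hypothetical and not part of the library. A plain class with a `forward` method stands in for an nn.Module so the snippet has no dependencies, and the single "action" output key is the assumption stated above.

```python
import inspect


def infer_keys(policy):
    """Hypothetical helper: derive (in_keys, out_keys) from policy.forward."""
    named_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # The signature of a bound method already excludes `self`.
    params = inspect.signature(policy.forward).parameters.values()
    # *args / **kwargs carry no key names, so they are skipped.
    in_keys = [p.name for p in params if p.kind in named_kinds]
    # Assumption from the proposal: the first (here, only) output is the action.
    out_keys = ["action"]
    return in_keys, out_keys


class MyRegularModule:
    """Stand-in for an nn.Module policy (assumption, to avoid importing torch)."""

    def forward(self, observation):
        return observation


in_keys, out_keys = infer_keys(MyRegularModule())
```

With these keys in hand, the collector could presumably construct the TensorDictModule wrapper itself instead of asking the user for one.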
Checks when passing the policy
- If the policy is not an nn.Module, there's nothing we can do with it.
- The policy signature must match the environment observation_spec (without the "next_" prefixes), e.g. the following code snippet would only work if the env has keys "next_observation" and "next_stuff" in its observation_spec:
  class MyPolicy(nn.Module):
      def forward(self, observation, stuff):
          return foo()
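The signature check described above could look roughly like the following sketch. `check_policy_signature` is a hypothetical name, the observation spec is reduced to a plain list of key strings, and a plain class stands in for the nn.Module (so the "is it an nn.Module" check is left out). None of this is the library's actual API.

```python
import inspect


def check_policy_signature(policy, observation_spec_keys):
    """Hypothetical check: every required forward() argument must be in the spec."""
    prefix = "next_"
    # Strip the "next_" prefix from the spec keys before comparing.
    available = {
        key[len(prefix):] if key.startswith(prefix) else key
        for key in observation_spec_keys
    }
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    # Arguments with defaults are treated as optional (an assumption).
    required = [
        p.name
        for p in inspect.signature(policy.forward).parameters.values()
        if p.kind not in variadic and p.default is inspect.Parameter.empty
    ]
    missing = [name for name in required if name not in available]
    if missing:
        raise ValueError(
            f"policy expects {missing}, not found in observation_spec "
            f"keys {sorted(available)}"
        )
    return required


class MyPolicy:
    """Stand-in for the nn.Module in the example above."""

    def forward(self, observation, stuff):
        return observation


matched = check_policy_signature(MyPolicy(), ["next_observation", "next_stuff"])
```

Raising at construction time would let a collector report a mismatch before any environment step is taken.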
Open questions:
- How can we support policies that have multiple outputs (e.g. a log_probability, a hidden state, etc.)?
We could name these "output1" ... "outputN"?
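One possible reading of that naming idea, as a sketch only: keep "action" for the first output and number the remaining ones. Both the helper name `default_out_keys` and the choice to start numbering at the second output are assumptions. The number of outputs cannot be read from the signature, so this version takes the result of one forward call.

```python
def default_out_keys(outputs):
    """Hypothetical helper: name policy outputs "action", "output1", "output2", ..."""
    # A single return value is treated as a one-element tuple.
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs,)
    # First output is assumed to be the action; the rest are numbered from 1.
    return ["action"] + [f"output{i}" for i in range(1, len(outputs))]


# Example: a policy returning (action, log_probability, hidden_state).
keys = default_out_keys(("a", "logp", "h"))
```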