[Feature Request] Allow data collectors to receive policies that read and write regular tensors #544

Closed
@vmoens

Motivation

As a modular library, we'd like our components to be as orthogonal to each other as possible.
As such, we should aim to minimise the inter-dependency of the components (at least in what is exposed to users).
An example of such a restriction is the fact that our data collectors expect policies that are TensorDictModule instances.
It should be possible to automatically build a TensorDictModule just by reading the signature of a policy.

Example:

policy = MyRegularModule()  # a regular nn.Module that reads and writes tensors
collector = datacollector(make_env, policy)  # the collector wraps the policy itself
assert isinstance(collector.policy, TensorDictModule)
assert collector.policy.out_keys == ["action"]
assert collector.policy.module is policy

We would assume that the first output of the policy is an action (and is written under the "action" key).

Checks when passing the policy

  • If the policy is not an nn.Module, there's nothing we can do with it.
  • The policy signature must match the environment's observation_spec (with the "next_" prefixes stripped). For example, the following code snippet would only work if the env has the keys "next_observation" and "next_stuff" in its observation_spec:
    class MyPolicy(nn.Module):
        def forward(self, observation, stuff):
            return foo()
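
The signature check described above could be sketched roughly as follows. This is only an illustration using the standard library: `infer_in_keys` and `check_policy_signature` are hypothetical names, the observation keys are passed as a plain list rather than a real observation_spec, and `MyPolicy` is a plain class standing in for an nn.Module so that the snippet runs without torch.

```python
import inspect


def infer_in_keys(policy):
    """Return the names of the positional arguments of policy.forward."""
    # inspect.signature on a bound method already leaves out `self`
    params = inspect.signature(policy.forward).parameters.values()
    positional = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return [p.name for p in params if p.kind in positional]


def check_policy_signature(policy, observation_keys):
    """Raise if the policy expects an input that the env does not provide."""
    # strip the "next_" prefix from the observation_spec keys
    available = {
        key[len("next_"):] if key.startswith("next_") else key
        for key in observation_keys
    }
    in_keys = infer_in_keys(policy)
    missing = [key for key in in_keys if key not in available]
    if missing:
        raise ValueError(f"policy expects keys not provided by the env: {missing}")
    return in_keys


class MyPolicy:  # stand-in for an nn.Module (assumption, to avoid importing torch)
    def forward(self, observation, stuff):
        return observation


in_keys = check_policy_signature(MyPolicy(), ["next_observation", "next_stuff"])
```

The returned names could then serve as the in_keys of the automatically built TensorDictModule; whether extra keys in the observation_spec should be tolerated, as they are here, is a design choice left open.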
    

Open questions:

  • How can we support policies that have multiple outputs (e.g. a log-probability, a hidden state, etc.)?
    We could name these "output1" ... "outputN"?
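
One possible reading of this naming scheme, as a sketch only: the first output keeps the "action" key and the remaining ones are numbered from "output1". The helper names `default_out_keys` and `label_outputs` are hypothetical, and starting the numbering at the second output is an assumption, not something settled in this issue.

```python
def default_out_keys(num_outputs):
    """Build default keys: "action" first, then "output1", "output2", ..."""
    if num_outputs < 1:
        raise ValueError("a policy must return at least one output")
    # assumption: numbering starts at 1 for the first non-action output
    return ["action"] + [f"output{i}" for i in range(1, num_outputs)]


def label_outputs(outputs):
    """Map the raw outputs of a policy to their default keys."""
    # a policy returning a single value is treated as a one-element tuple
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    return dict(zip(default_out_keys(len(outputs)), outputs))


labelled = label_outputs(("some_action", "some_log_prob", "some_hidden_state"))
```

A drawback of positional names like these is that downstream code cannot tell a log-probability from a hidden state, so letting users override the out_keys would probably still be needed.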

Labels

enhancement: New feature or request
