
[BugFix] Fix shifted value computation with an LSTM #2941


Merged
5 commits merged into gh/vmoens/142/base from gh/vmoens/142/head on May 14, 2025

Conversation

@vmoens (Collaborator) commented May 7, 2025

[ghstack-poisoned]
pytorch-bot (bot) commented May 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2941

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vmoens pushed a commit that referenced this pull request May 7, 2025
ghstack-source-id: b912177
Pull-Request-resolved: #2941
@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on May 7, 2025
[ghstack-poisoned]
vmoens pushed a commit that referenced this pull request May 7, 2025
ghstack-source-id: ac120cd
Pull-Request-resolved: #2941
[ghstack-poisoned]
vmoens pushed a commit that referenced this pull request May 13, 2025
ghstack-source-id: a9d9bc9
Pull-Request-resolved: #2941
@vmoens added the bug label on May 13, 2025
[ghstack-poisoned]
vmoens pushed a commit that referenced this pull request May 13, 2025
ghstack-source-id: 08b4697
Pull-Request-resolved: #2941
Comment on lines 564 to 570
# Test:
# import torch
# from tensordict import TensorDict
# a = torch.randn(3, 4)
# b = TensorDict(a=torch.randn(3, 4), batch_size=3)
# ds = (0, 0)
# v0, v1, v2 = zip(*tuple(tree_map(lambda d, x: x.unbind(d), (0, 0), (a, b))))
Contributor

Should this be removed before merging?

Collaborator Author

Yep, I'll make a test out of it.
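
For reference, a rough sketch of what such a test could look like (hypothetical test name and assertions; tree_map is assumed to come from torch.utils._pytree and to accept multiple trees, as in recent PyTorch versions):

import torch
from tensordict import TensorDict
from torch.utils._pytree import tree_map

def test_tree_map_unbind_regroups_per_index():
    a = torch.randn(3, 4)
    b = TensorDict(a=torch.randn(3, 4), batch_size=3)
    dims = (0, 0)
    # Unbind each structure along its dim, then regroup the slices per index.
    v0, v1, v2 = zip(*tree_map(lambda d, x: x.unbind(d), dims, (a, b)))
    assert v0[0].shape == (4,)                  # first slice of the plain tensor
    assert v0[1].batch_size == torch.Size([])   # first slice of the TensorDict
    assert torch.equal(v1[0], a[1])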

Comment on lines 445 to 448
for i, name in enumerate(data.names):
    if name == "time":
        ndim = i + 1
        break
Contributor

Nit: Should we instead rely on ndim = data.names.index("time") + 1 with a try/except? Would be cleaner IMO.

Collaborator Author

Fair point. The reason I wasn't doing it is that try/except isn't compatible with compile, but in practice, once we reach the warning we're in non-compilable code anyway.
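
For illustration, the two variants being discussed side by side (a sketch: data stands for a TensorDict whose names may contain "time", and None is only a placeholder for whatever fallback the real code uses):

# Current approach: a plain loop, friendly to torch.compile.
ndim = None
for i, name in enumerate(data.names):
    if name == "time":
        ndim = i + 1
        break

# Reviewer's alternative: shorter, but the try/except is what trips up compile.
try:
    ndim = data.names.index("time") + 1
except ValueError:
    ndim = None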

@@ -435,62 +441,70 @@ def _call_value_nets(
value_net = self.value_network
in_keys = value_net.in_keys
if single_call:
    # Reshape to -1 because we cannot guarantee that all dims have the same number of done states
Contributor

I am not sure I follow why this comment is here.

Collaborator Author

It should be moved down.

slice(data.shape[ndim - 1], None),
data_copy = data.copy()
done = data_copy["next", "done"].clone()
done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
Contributor

Suggested change
- done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
+ # Mark the last step of every sequence as done.
+ done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)

idx_ = (slice(None),) * (ndim - 1) + (
slice(data.shape[ndim - 1], None),
data_copy = data.copy()
done = data_copy["next", "done"].clone()
Contributor

I'll admit I'm having quite a hard time following the logic that lies below, and more importantly why it is needed. I think a few comments could really help the next reader :)
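
To help the next reader, here is a deliberately simplified sketch of the idea the snippet above seems to implement (not the actual implementation; plain tensors instead of tensordicts, and the function and argument names are made up): call the value network once over the observations with the final "next" observation appended, then split the result into value and next_value. The done-marking and index juggling in the real code exist precisely because intermediate dones break the naive shift below.

import torch

def single_call_values(obs, next_obs, value_net):
    # obs, next_obs: [..., T, F]; value_net maps features F to a scalar value.
    # Append the last "next" observation so one forward pass covers both the
    # T observations and their T shifted counterparts.
    extended = torch.cat([obs, next_obs[..., -1:, :]], dim=-2)  # [..., T + 1, F]
    values = value_net(extended)                                # [..., T + 1, 1]
    value, next_value = values[..., :-1, :], values[..., 1:, :]
    # Caveat: when a trajectory ends mid-batch (done=True before the last
    # step), obs[t + 1] is a reset state rather than next_obs[t], so the real
    # code also tracks done flags instead of relying on this raw shift.
    return value, next_value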

Comment on lines 13800 to 13801
for p in recurrent_module.parameters():
    p.data *= 1 + torch.randn_like(p.data) / 10
Contributor

Curious as to why we're overwriting the RNN parameters here?

Collaborator Author

I did that to make sure the test had enough power to detect slight changes in behavior (i.e., that every item that could be numerically different from its counterpart computed with the opposite shifted param actually would be), but in practice it's less useful now.

recurrent_module,
Mod(mlp_policy, in_keys=["intermediate"], out_keys=["action"]),
)
# value_net.select_out_keys("state_value")
Contributor

Suggested change
- # value_net.select_out_keys("state_value")

# value_net.select_out_keys("state_value")
env = env.append_transform(recurrent_module.make_tensordict_primer())
vals = env.rollout(1000, policy_net, break_when_any_done=False)
value_net(vals.copy())
Contributor

I think this call to the value net is not necessary?

Suggested change
- value_net(vals.copy())

Collaborator Author

It's there for init purposes.
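
A plausible reading (an assumption, not stated in the test): the value net contains lazy layers, which only materialize their parameters on the first forward call, e.g.:

import torch
from torch import nn

# Stand-in module, not the one from the test.
value_net = nn.Sequential(nn.LazyLinear(64), nn.Tanh(), nn.LazyLinear(1))
value_net(torch.randn(8, 32))  # the first call materializes the lazy parameters
print(sum(p.numel() for p in value_net.parameters()))  # now well-defined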

Comment on lines 13828 to 13829
# TODO: where to put this?
vals["next", "is_init"] = vals["is_init"]
Contributor

Should we solve this TODO before merging?

Collaborator Author

Yes! I think this is something we can safely ask GAE to do for us.

Comment on lines +791 to +792
_hidden0_in = hidden0_in[..., 0, :, :]
_hidden1_in = hidden1_in[..., 0, :, :]
Contributor

This change is interesting. I'm guessing the shape of hiddenN_in is (batch, sequence, num_lstm_layers, hidden_size).

Were we previously using the first hidden layer of the LSTM rather than the first hidden state of the sequence?

Collaborator Author

It's more an aesthetic change than anything. Just after this we swap dims based on negative indices, so indexing from the end all the way through helps clarify things. In both cases we're indexing the 0th element of the sequence.
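
A quick check of the indexing in question, assuming the (batch, sequence, num_lstm_layers, hidden_size) layout guessed above:

import torch

hidden0_in = torch.randn(2, 5, 3, 8)     # [batch, time, layers, hidden]
first_step = hidden0_in[..., 0, :, :]    # step 0 of every sequence -> [2, 3, 8]
assert first_step.shape == (2, 3, 8)
assert torch.equal(first_step, hidden0_in[:, 0])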

[ghstack-poisoned]
vmoens pushed a commit that referenced this pull request May 14, 2025
ghstack-source-id: 9ccbf82
Pull-Request-resolved: #2941
@vmoens merged commit 9f5aaef into gh/vmoens/142/base on May 14, 2025
50 of 69 checks passed
vmoens pushed a commit that referenced this pull request May 14, 2025
ghstack-source-id: 9ccbf82
Pull-Request-resolved: #2941
@vmoens deleted the gh/vmoens/142/head branch on May 14, 2025 at 08:51
vmoens pushed a commit that referenced this pull request May 16, 2025
ghstack-source-id: 9ccbf82
Pull-Request-resolved: #2941
(cherry picked from commit 1813e8e)
Labels: bug, CLA Signed

4 participants