[BugFix] Fix shifted value computation with an LSTM #2941
Conversation
torchrl/objectives/utils.py
# Test:
# import torch
# from tensordict import TensorDict
# from torch.utils._pytree import tree_map
# a = torch.randn(3, 4)
# b = TensorDict(a=torch.randn(3, 4), batch_size=3)
# ds = (0, 0)
# v0, v1, v2 = zip(*tuple(tree_map(lambda d, x: x.unbind(d), ds, (a, b))))
Should this be removed before merging?
yep i'll make a test out of it
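For reference, a rough sketch of how that snippet could become a standalone test, assuming tree_map comes from torch.utils._pytree (the test name and assertions below are illustrative, not the actual test added in this PR):

import torch
from tensordict import TensorDict
from torch.utils._pytree import tree_map


def test_tree_map_unbind():
    # Unbind a plain tensor and a TensorDict along dim 0 in a single tree_map call.
    a = torch.randn(3, 4)
    b = TensorDict(a=torch.randn(3, 4), batch_size=3)
    ds = (0, 0)  # one unbind dim per leaf
    v0, v1, v2 = zip(*tuple(tree_map(lambda d, x: x.unbind(d), ds, (a, b))))
    # Each vi pairs the i-th slice of the tensor with the i-th slice of the TensorDict.
    for i, (ai, bi) in enumerate((v0, v1, v2)):
        assert ai.shape == (4,)
        assert bi.batch_size == torch.Size([])
        torch.testing.assert_close(ai, a[i])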
for i, name in enumerate(data.names):
    if name == "time":
        ndim = i + 1
        break
Nit: Should we instead rely on ndim = data.names.index("time") + 1 wrapped in a try/except? It would be cleaner imo.
Fair point. The reason I wasn't doing it is that the try/except isn't compatible with compile, but in practice, once we reach the warning we're in non-compilable code anyway.
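For context, the two variants under discussion look roughly like this (a sketch with a stand-in names list, not the exact code in the PR; the None fallback is an assumption):

names = ["batch", "time"]  # stand-in for data.names

# Loop-based variant (kept in the PR): no exception handling, so it stays
# friendly to torch.compile tracing.
ndim = None
for i, name in enumerate(names):
    if name == "time":
        ndim = i + 1
        break

# try/except variant suggested in the review: shorter, but as noted above the
# exception handling doesn't play well with compile.
try:
    ndim = names.index("time") + 1
except ValueError:
    ndim = None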
@@ -435,62 +441,70 @@ def _call_value_nets(
    value_net = self.value_network
    in_keys = value_net.in_keys
    if single_call:
        # Reshape to -1 because we cannot guarantee that all dims have the same number of done states
I am not sure I follow why this comment is here.
should be moved down
    slice(data.shape[ndim - 1], None),
data_copy = data.copy()
done = data_copy["next", "done"].clone()
done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
Suggested change:
- done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
+ # Mark the last step of every sequence as done.
+ done[(slice(None),) * (ndim - 1) + (-1,)].fill_(True)
idx_ = (slice(None),) * (ndim - 1) + (
    slice(data.shape[ndim - 1], None),
data_copy = data.copy()
done = data_copy["next", "done"].clone()
I'll admit I'm having quite a hard time following the logic that lies below, but more importantly why it is needed. I think a few comments could really help the next reader :)
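Not the PR's actual implementation, but a toy sketch of the single-call idea as I understand it (all names and shapes are made up; the real code additionally has to handle intermediate done flags, which is what the done masking above is about):

import torch

B, T, F = 2, 5, 3                      # batch, time and feature sizes (illustrative)
obs = torch.randn(B, T, F)             # observations o_0 ... o_{T-1}
last_next_obs = torch.randn(B, 1, F)   # the final "next" observation o_T
value_net_fn = torch.nn.Linear(F, 1)   # stand-in for the real value network

# Single call: evaluate the T visited states plus the final next state at once.
stacked = torch.cat([obs, last_next_obs], dim=1)  # (B, T + 1, F)
values = value_net_fn(stacked)                    # (B, T + 1, 1)

# V(s_t) is the first T entries; shifting by one step gives V(s_{t+1}).
state_value = values[:, :-1]
next_state_value = values[:, 1:]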
test/test_cost.py
for p in recurrent_module.parameters():
    p.data *= 1 + torch.randn_like(p.data) / 10
Curious as to why we're overwriting the RNN parameters?
I did that to make sure we had enough power to detect slight changes in behavior (i.e., that every item that could differ numerically from the same item computed with the opposite shifted param actually would), but in practice it's less useful now.
test/test_cost.py
    recurrent_module,
    Mod(mlp_policy, in_keys=["intermediate"], out_keys=["action"]),
)
# value_net.select_out_keys("state_value")
Suggested change:
- # value_net.select_out_keys("state_value")
# value_net.select_out_keys("state_value")
env = env.append_transform(recurrent_module.make_tensordict_primer())
vals = env.rollout(1000, policy_net, break_when_any_done=False)
value_net(vals.copy())
I think this call to the value net is not necessary?
value_net(vals.copy())
it's there for init purposes
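(Context for readers: lazy modules only materialize their parameters on the first forward pass, so a throwaway call like this is a common way to initialize them. A minimal standalone illustration, unrelated to the actual fixtures in this test:)

import torch
from torch import nn

lazy = nn.LazyLinear(out_features=1)  # in_features is inferred at the first call
assert isinstance(lazy.weight, nn.parameter.UninitializedParameter)

lazy(torch.randn(8, 4))               # a dummy forward pass materializes the parameters
assert lazy.weight.shape == (1, 4)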
test/test_cost.py
# TODO: where to put this?
vals["next", "is_init"] = vals["is_init"]
Should we solve this TODO before merging?
yes! I think this is something we can safely ask GAE to do for us
_hidden0_in = hidden0_in[..., 0, :, :]
_hidden1_in = hidden1_in[..., 0, :, :]
This change is interesting. I'm guessing the shape of hiddenN_in is (batch, sequence, num_lstm_layers, hidden_size). Before this change, were we using not the first hidden state of a sequence, but the first hidden layer in the LSTM?
It's more of an aesthetic change than anything. Just after this, we swap dims based on negative indices, so using negative indices all the way through helps clarify things. In both cases, we're indexing the 0th element of the sequence.
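A tiny sanity check of that equivalence, assuming the (batch, sequence, num_lstm_layers, hidden_size) layout guessed above:

import torch

hidden0_in = torch.randn(2, 5, 3, 7)       # hypothetical (batch, sequence, layers, hidden)

from_the_right = hidden0_in[..., 0, :, :]  # as in the new code: count dims from the right
from_the_left = hidden0_in[:, 0]           # positional equivalent: 0th element of the sequence
assert torch.equal(from_the_right, from_the_left)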