Open
Labels: bug (Something isn't working)
Description
Describe the bug
When a TensorDict is given bool data, the values are converted internally to 0-dim tensors. As a result, `auto_batch_size_` can't infer the batch size, and setting the batch size explicitly raises an exception.
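To make the failure mode concrete, here is a minimal pure-Python model of batch-size inference. It assumes that `auto_batch_size_` takes the longest common leading shape of all tensor entries, capped at `batch_dims`. The function `infer_batch_size` and the convention of writing non-tensor entries as `None` are hypothetical; this is a sketch, not tensordict's implementation.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def infer_batch_size(
    entries: Dict[str, Optional[Shape]], batch_dims: Optional[int] = None
) -> Shape:
    """Longest common leading shape of all tensor entries (hypothetical model).

    A value of None stands for non-tensor data such as a string; such entries
    are skipped, matching the behaviour reported for strings.
    """
    shapes = [s for s in entries.values() if s is not None]
    prefix = []
    # zip stops at the shortest shape, so a 0-dim shape () yields no dims at all.
    for dims in zip(*shapes):
        if len(set(dims)) != 1:
            break
        prefix.append(dims[0])
    if batch_dims is not None:
        prefix = prefix[:batch_dims]
    return tuple(prefix)


# String value: non-tensor data, so only "tt" constrains the batch size.
print(infer_batch_size({"no_bs_arg": None, "tt": (2, 3)}, batch_dims=1))  # (2,)
# Bool value: a 0-dim tensor with shape (), which empties the common prefix.
print(infer_batch_size({"no_bs_arg": (), "tt": (2, 3)}, batch_dims=1))  # ()
```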
To Reproduce
```python
import torch
from tensordict import TensorDict

# Passing a "string" argument (non-tensor data)
td = TensorDict(no_bs_arg="True")
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([])
td["tt"] = torch.rand(2, 3)
assert td.batch_size == torch.Size([])
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([2])  # batch size inferred from "tt"

# Passing a "bool" argument (also non-tensor data, but converted to a tensor internally)
td = TensorDict(no_bs_arg=True)
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([])
td["tt"] = torch.rand(2, 3)
assert td.batch_size == torch.Size([])
td.auto_batch_size_(1)
assert td.batch_size == torch.Size([])  # bug: batch size is not inferred
try:
    td.batch_size = torch.Size([2])
except RuntimeError as e:
    # bug: the 0-dim "no_bs_arg" entry blocks setting the batch size explicitly
    assert (
        str(e)
        == "the tensor no_bs_arg has shape torch.Size([]) which is incompatible with the batch-size torch.Size([2])."
    )
```
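The exception in the last step can be modelled the same way. The sketch below assumes the batch-size check requires every tensor entry's shape to start with the new batch size. `check_batch_size` is a hypothetical stand-in that loosely follows the error message quoted above; it is not tensordict's actual `_check_new_batch_size`.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def check_batch_size(entries: Dict[str, Optional[Shape]], batch_size: Shape) -> None:
    """Raise if a tensor entry's leading dims differ from batch_size (hypothetical model).

    None stands for non-tensor data, which is not checked.
    """
    n = len(batch_size)
    for key, shape in entries.items():
        if shape is None:
            continue
        # shape[:n] of a 0-dim shape () is still (), so it never matches a non-empty batch size.
        if shape[:n] != tuple(batch_size):
            raise RuntimeError(
                f"the tensor {key} has shape {shape} which is incompatible "
                f"with the batch-size {batch_size}."
            )


check_batch_size({"no_bs_arg": None, "tt": (2, 3)}, (2,))  # string case: passes
try:
    check_batch_size({"no_bs_arg": (), "tt": (2, 3)}, (2,))  # bool case
except RuntimeError as e:
    print(e)  # the tensor no_bs_arg has shape () which is incompatible with the batch-size (2,).
```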
Expected behavior
Bool values should be handled like strings and other non-tensor data.
Reason and Possible fixes
The reason is that bool arguments are converted into tensors internally. While using tensors as the internal representation may be more efficient, perhaps tensors with `.ndim == 0` should be ignored both when the batch size is computed automatically and in `_check_new_batch_size`.
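A minimal sketch of the proposed fix, on a toy model where each entry is represented by its shape tuple (`None` for non-tensor data). The helpers `infer_batch_size` and `check_batch_size` and the `ignore_0dim` flag are hypothetical stand-ins for batch-size inference and `_check_new_batch_size`; entries with an empty shape, the analogue of `.ndim == 0`, are skipped in both places.

```python
from typing import Dict, Iterator, Optional, Tuple

Shape = Tuple[int, ...]
Entries = Dict[str, Optional[Shape]]


def _constraining_shapes(entries: Entries, ignore_0dim: bool) -> Iterator[Tuple[str, Shape]]:
    """Yield (key, shape) for the entries that should constrain the batch size."""
    for key, shape in entries.items():
        if shape is None:  # non-tensor data (e.g. a string) is already skipped
            continue
        if ignore_0dim and len(shape) == 0:  # proposed fix: skip .ndim == 0 tensors
            continue
        yield key, shape


def infer_batch_size(
    entries: Entries, batch_dims: Optional[int] = None, ignore_0dim: bool = True
) -> Shape:
    shapes = [s for _, s in _constraining_shapes(entries, ignore_0dim)]
    prefix = []
    for dims in zip(*shapes):  # longest common leading shape
        if len(set(dims)) != 1:
            break
        prefix.append(dims[0])
    return tuple(prefix if batch_dims is None else prefix[:batch_dims])


def check_batch_size(entries: Entries, batch_size: Shape, ignore_0dim: bool = True) -> None:
    n = len(batch_size)
    for key, shape in _constraining_shapes(entries, ignore_0dim):
        if shape[:n] != tuple(batch_size):
            raise RuntimeError(
                f"the tensor {key} has shape {shape} which is incompatible "
                f"with the batch-size {batch_size}."
            )


entries = {"no_bs_arg": (), "tt": (2, 3)}  # bool stored as a 0-dim tensor
print(infer_batch_size(entries, batch_dims=1))  # (2,)
check_batch_size(entries, (2,))  # no longer raises
```

Passing `ignore_0dim=False` reproduces the reported behaviour (`()` and a `RuntimeError`). One trade-off to weigh: this rule would also skip scalar tensors a user stored deliberately, so keeping bools as non-tensor data, as is done for strings, may be the narrower change.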
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)