
Commit 3efe1eb
update runtime error message for minibatch (#8243)
JackCaoG authored Oct 30, 2024
1 parent dc20b2d commit 3efe1eb
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion test/spmd/test_xla_sharding.py
@@ -1373,7 +1373,7 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self):
             mesh, ('data', None, None, None), minibatch=True))
     with self.assertRaisesRegex(
         RuntimeError,
-        "When minibatch is configured, batch dimension of the tensor must be divisible by local runtime device count*"
+        "When minibatch is configured, the per-host batch size must be divisible by local runtime device count. Per host input data shape *"
     ):
       data, _ = iter(train_device_loader).__next__()

7 changes: 3 additions & 4 deletions torch_xla/core/xla_model.py
@@ -1309,10 +1309,9 @@ def convert_fn(tensors):
     if sharding and tensor.dim() > 0 and (tensor.size()[0] %
                                           local_runtime_device_count) != 0:
       raise RuntimeError(
-          "When minibatch is configured, batch dimension of the tensor " +
-          "must be divisible by local runtime device count.input data shape "
-          +
-          f"={tensor.size()}, local_runtime_device_count = {local_runtime_device_count}"
+          "When minibatch is configured, the per-host batch size must be divisible "
+          + "by local runtime device count. Per host input data shape " +
+          f"= {tensor.size()}, local_runtime_device_count = {local_runtime_device_count}"
       )

     xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
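For readers without the surrounding file context, below is a minimal, self-contained sketch of the check the updated code performs: under minibatch input sharding, dimension 0 of each input tensor is the per-host batch, and it must divide evenly across the local runtime devices. The helper name check_minibatch_batch_size and the device count of 4 are illustrative assumptions, not part of torch_xla's API.

import torch

def check_minibatch_batch_size(tensor: torch.Tensor,
                               local_runtime_device_count: int) -> None:
  # Mirrors the divisibility check in convert_fn: with minibatch sharding,
  # dim 0 (the per-host batch) must split evenly across local devices.
  if tensor.dim() > 0 and tensor.size()[0] % local_runtime_device_count != 0:
    raise RuntimeError(
        "When minibatch is configured, the per-host batch size must be divisible "
        + "by local runtime device count. Per host input data shape " +
        f"= {tensor.size()}, local_runtime_device_count = {local_runtime_device_count}")

# A batch of 8 splits evenly across 4 local devices, so this passes silently.
check_minibatch_batch_size(torch.zeros(8, 3), local_runtime_device_count=4)

# A batch of 5 does not, so this raises the RuntimeError shown in the diff.
try:
  check_minibatch_batch_size(torch.zeros(5, 3), local_runtime_device_count=4)
except RuntimeError as e:
  print(e)

The test hunk above exercises exactly this path: its assertRaisesRegex pattern is a prefix of the new message, so it matches regardless of the tensor shape reported at the end.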
