
Concatenate str support for IterableDataset #3686


Open
wants to merge 11 commits into
base: main
from

Conversation


@TitouanCh TitouanCh commented Jul 18, 2025

What does this PR do?

This PR changes concatenate so that it no longer crashes when batches contain non-tensor values such as strings.
Right now it raises a TypeError if a batch contains anything other than tensors, for example string labels.
This is particularly useful for IterableDataset. I also added a new test.

Fixes #3624 #1878

Proposal

Initial proposal: when concatenating lists, check whether the first element is a str; if so, concatenate the strings into a flat Python list instead of a tensor (I could also check all elements if you would like to avoid cases like ["test", 0] + ["test", "test"] = ["test", 0, "test", "test"]).
Update: I changed the logic to check all elements when concatenating lists. If all elements are strings, they are concatenated into a flat Python list.
If this is accepted, this can easily be adapted to support more types if needed.
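
A minimal sketch of the all-elements check described above, in pure Python. The helper name concat_str_lists is hypothetical and is not the code added by this PR; it only illustrates the flattening behaviour and why checking every element matters.

```python
def concat_str_lists(batches):
    """Flatten a list of string lists into one list (hypothetical helper)."""
    # Check every element, not just the first, so mixed lists are rejected.
    for batch in batches:
        for item in batch:
            if not isinstance(item, str):
                raise TypeError(f"Expected only str elements but got {type(item)}")
    # Preserve batch order: items of the first batch, then the second, and so on.
    return [item for batch in batches for item in batch]


merged = concat_str_lists([["dog", "cat"], ["koala", "iguana"]])
print(merged)
```

With a first-element-only check, an input such as [["test", 0], ["test", "test"]] would pass silently; here it raises a TypeError.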

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@BenjaminBossan @SunMarc @zach-huggingface

Minimal example script showing the problem

import torch
import numpy as np
from collections.abc import Mapping
from accelerate.utils import concatenate


def test_batches():
    print("=== Test 1: Tensor batches ===")
    batch1 = {
        "x": torch.rand(4, 1),
        "y": torch.from_numpy(
            np.array(
                [[1.0, 2.0, 3.0]] * 4,
                dtype=np.float32,
            )
        ),
    }

    batch2 = {
        "x": torch.rand(4, 1),
        "y": torch.from_numpy(
            np.array(
                [[1.0, 2.0, 3.0]] * 4,
                dtype=np.float32,
            )
        ),
    }

    batch = concatenate([batch1, batch2], dim=0)

    print(batch)
    print("x shape:", batch["x"].shape)  # Should be (8, 1)
    print("y shape:", batch["y"].shape)  # Should be (8, 3)

    print("\n=== Test 2: Mixed types (with lists) ===")
    batch1 = {"x": torch.rand(4, 1), "animals": ["dog", "cat", "baby", "penguin"]}
    batch2 = {
        "x": torch.rand(4, 1),
        "animals": ["koala", "samurai", "iguana", "rabbit"],
    }

    batch = concatenate([batch1, batch2], dim=0)

    print(batch)
    print("x shape:", batch["x"].shape)  # Should be (8, 1)
    print("animals:", batch["animals"])  # Should be batch1["animals"] + batch2["animals"]


if __name__ == "__main__":
    test_batches()
>> "src/accelerate/utils/operations.py", line 637 in concatenate
>> TypeError: Can only concatenate tensors but got <class 'str'>

Example with IterableDataset

import torch
from torch.utils.data import IterableDataset, DataLoader
import random
from accelerate import Accelerator


class SyntheticIterableDataset(IterableDataset):
    def __init__(self, num_samples, input_dim, vocab):
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.vocab = vocab

    def __iter__(self):
        for _ in range(self.num_samples):
            x = torch.randint(0, 256, (self.input_dim,), dtype=torch.float32)
            y_str = random.choice(self.vocab)  # str attribute
            yield {"x": x, "y": y_str}


def custom_collate(batch):
    collated = {}
    for sample in batch:
        for k, v in sample.items():
            if k not in collated:
                collated[k] = []
            collated[k].append(v)
    for k in collated:
        if torch.is_tensor(collated[k][0]):
            collated[k] = torch.stack(collated[k])
    return collated


if __name__ == "__main__":
    # Setup model
    vocab = ["cat", "dog", "snake"]
    vocab_to_idx = {label: i for i, label in enumerate(vocab)}
    num_classes = len(vocab)
    input_dim = 1000 + num_classes

    dataset = SyntheticIterableDataset(1000, 1000, vocab=vocab)
    dataloader = DataLoader(dataset, batch_size=32, collate_fn=custom_collate)

    model = torch.nn.Linear(input_dim, 10)  # Dummy linear model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    accelerator = Accelerator()
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    device = accelerator.device

    for epoch in range(10):
        for batch in dataloader:
            print(batch)
            x = batch["x"].to(device)

            # y to one-hot
            y_str_list = batch["y"]
            y_idx = torch.tensor([vocab_to_idx[y] for y in y_str_list], device=device)
            y_onehot = torch.nn.functional.one_hot(
                y_idx, num_classes=num_classes
            ).float()

            model_input = torch.cat([x, y_onehot], dim=1)

            output = model(model_input)
            loss = output.sum()

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()

        print(f"Epoch {epoch} complete.")
>> "src/accelerate/utils/operations.py", line 637 in concatenate
>> TypeError: Can only concatenate tensors but got <class 'str'>

Discussion

Could this be extended to support constants? For example:

batch1 = {
    "const": True,
    "const_num": 1,
    "x": torch.tensor([
        [1, 2, 3],
        [2, 3, 4]
    ]),
    "y": ["c", "d"],
}

batch2 = {
    "const": True,
    "const_num": 1,
    "x": torch.tensor([
        [1, 2, 3],
        [1, 4, 5]
    ]),
    "y": ["a", "b"],
}

batch_concat = {
    "const": True,
    "const_num": 1,
    "x": torch.vstack([batch1["x"], batch2["x"]]),
    "y": batch1["y"] + batch2["y"],
}
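
One possible way to handle such constants, sketched in pure Python. This assumes a "constant" is a scalar that must be identical in every batch; the helper name merge_constant and the choice to raise on disagreement are assumptions made for illustration, not part of this PR.

```python
def merge_constant(values):
    """Return the shared value if every batch agrees on it (hypothetical helper)."""
    first = values[0]
    for value in values[1:]:
        # Compare types too, so that True and 1 are not treated as the same constant.
        if type(value) is not type(first) or value != first:
            raise ValueError(f"Constant differs across batches: {first!r} vs {value!r}")
    return first


print(merge_constant([True, True]))
print(merge_constant([1, 1]))
```

Raising on a mismatch is only one option; another would be to collect differing values into a list, which would change the type of the field between batches.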

@TitouanCh TitouanCh marked this pull request as ready for review July 18, 2025 13:29
@TitouanCh TitouanCh marked this pull request as draft July 18, 2025 13:31
@TitouanCh TitouanCh force-pushed the concatenate-str-support branch from 47cca6c to 1a73970 Compare July 18, 2025 13:44
@TitouanCh TitouanCh marked this pull request as ready for review July 18, 2025 17:35
@TitouanCh TitouanCh force-pushed the concatenate-str-support branch from 9d49356 to 032c7d2 Compare July 22, 2025 13:07
Member

@SunMarc SunMarc left a comment


Left a comment!

Comment on lines 618 to 622
first_inner = data[0][0] if len(data[0]) > 0 else None

if isinstance(first_inner, str):
    return honor_type(data[0], [item for sublist in data for item in sublist])
else:
Member


We shouldn't modify that; instead, we should check elif isinstance(data[0], str) and make the respective changes there. Please also add test cases involving tuples and dictionaries. We should also check that no tensor is being concatenated with a str.

Author


Done!
Instead of checking for str directly, I check whether it's a list of strings, which made the logic simpler.
Let me know if that works for you, and I'm happy to explore another approach if needed.
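
For reference, a rough sketch of the "list of strings" check applied to nested tuples and dictionaries, as requested in the review. This is not the code in this PR: concat_nested_str is a hypothetical helper, tensor handling is deliberately left out, and any list that mixes strings with other values is rejected.

```python
from collections.abc import Mapping


def concat_nested_str(data):
    """Concatenate string lists nested in lists, tuples, or mappings.

    `data` is a list of batches that share the same structure.
    Hypothetical helper; tensors are intentionally not handled here.
    """
    first = data[0]
    if isinstance(first, Mapping):
        # Concatenate key by key, keeping the mapping type of the first batch.
        return type(first)({k: concat_nested_str([d[k] for d in data]) for k in first})
    if isinstance(first, (list, tuple)):
        if all(isinstance(item, str) for batch in data for item in batch):
            # Leaf case: every element is a string, so flatten into one list.
            return [item for batch in data for item in batch]
        if any(isinstance(item, str) for batch in data for item in batch):
            # Strings mixed with other values are rejected.
            raise TypeError("Cannot concatenate str with non-str values")
        # Otherwise recurse into the container position by position.
        return type(first)(concat_nested_str([d[i] for d in data]) for i in range(len(first)))
    raise TypeError(f"Unsupported type {type(first)}")


batch1 = (["a", "b"], {"y": ["c"]})
batch2 = (["d"], {"y": ["e", "f"]})
print(concat_nested_str([batch1, batch2]))
```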

@TitouanCh TitouanCh requested a review from SunMarc July 23, 2025 11:05
Successfully merging this pull request may close these issues.

dataloader with IterableDataset cannot work after accelerator.prepare()