WIP: multimodal support #227

Draft: wants to merge 83 commits into main from soham/pixtral-support

Commits (83)

7709e65
WIP: multimodal support
sohamparikh Apr 8, 2025
0db2bd2
rough idea for memmap
sohamparikh Apr 9, 2025
0d89f68
faster image size reading
sohamparikh Apr 15, 2025
3866a53
solidify prepare
sohamparikh Apr 21, 2025
8413983
wip
sohamparikh Apr 24, 2025
6521e41
vision model
sohamparikh Apr 24, 2025
daf586f
wip
sohamparikh Apr 24, 2025
ef4488d
wip
sohamparikh Apr 25, 2025
6d9d595
missing files
sohamparikh Apr 28, 2025
6cb8f5d
make it work, barely
sohamparikh Apr 30, 2025
5761a2d
fix
sohamparikh Apr 30, 2025
d45d600
fixes
sohamparikh May 1, 2025
74a99b8
changes
sohamparikh May 6, 2025
99ad5d9
patches and fixes
sohamparikh May 7, 2025
bcb557a
fix dependency
sohamparikh May 7, 2025
a6f5364
remove for testing
sohamparikh May 7, 2025
73b431b
mising
sohamparikh May 7, 2025
6d65676
fix
sohamparikh May 8, 2025
46aefc1
Merge branch 'main' into soham/pixtral-support
sohamparikh May 9, 2025
66e7081
fixes
sohamparikh May 9, 2025
7f86a7f
fix
sohamparikh May 12, 2025
3a8a99d
more fixes after merge
sohamparikh May 12, 2025
d16284e
conv cleanup
sohamparikh May 12, 2025
b3134aa
more conv cleanup
sohamparikh May 12, 2025
c8aa66e
images + loss-masks
sohamparikh May 13, 2025
0baae59
minor fixes
sohamparikh May 13, 2025
48855be
cleanup
sohamparikh May 13, 2025
f35e003
cleanup
sohamparikh May 13, 2025
4eb34cb
cleanup
sohamparikh May 13, 2025
ebb9e27
cleanup
sohamparikh May 13, 2025
51098ef
fix
sohamparikh May 13, 2025
60b87fa
prepare cleanup
sohamparikh May 13, 2025
f8a5532
slightly better conversion
sohamparikh May 13, 2025
490651e
cleanup, sequence parallelism
sohamparikh May 14, 2025
24e1b83
fix conv
sohamparikh May 14, 2025
0f1612a
wip fixes
sohamparikh May 14, 2025
2e48c5f
fix
sohamparikh May 14, 2025
d529d37
fix image position
sohamparikh May 17, 2025
3c22dda
cleanup
sohamparikh May 17, 2025
f0c8d83
cleanup
sohamparikh May 20, 2025
ca33ee8
cleaner, extensible multimodal config
sohamparikh May 21, 2025
f3a4a74
cleanup
sohamparikh May 21, 2025
3b955b1
fixes for pixtral
sohamparikh May 21, 2025
49daf58
model fixes
sohamparikh May 21, 2025
b5ed9f4
more cleanup
sohamparikh May 22, 2025
dc888c8
image break token in sampling
sohamparikh May 22, 2025
af3e2db
minor fixes
sohamparikh May 23, 2025
6d56be0
fix img break
sohamparikh May 24, 2025
ce91646
fixes
sohamparikh May 27, 2025
204b3e9
fix image embeddings offset
sohamparikh May 28, 2025
fd08eac
heterogeneous data fixes
sohamparikh May 29, 2025
1e3652a
convert to rgb
sohamparikh May 29, 2025
2aabf35
fix sequence parallel image patches
sohamparikh May 30, 2025
b6d4858
fixes
sohamparikh May 31, 2025
25a650b
no compile for embeddings
sohamparikh May 31, 2025
c904da5
fix sampling
sohamparikh Jun 1, 2025
7a4701c
sampling and preprocessing bugs
sohamparikh Jun 2, 2025
067f901
speed up sampling
sohamparikh Jun 2, 2025
f24325e
cap image size reduction
sohamparikh Jun 2, 2025
0f37664
fix span offset with images
sohamparikh Jun 2, 2025
ff8fecc
fix span offset with images
sohamparikh Jun 2, 2025
c663cbb
move image logic to sampled
sohamparikh Jun 3, 2025
f52f02b
cleanup
sohamparikh Jun 3, 2025
5436357
merge main
sohamparikh Jun 4, 2025
02f6d8f
cleanup
sohamparikh Jun 5, 2025
6843129
jpeg dependency
sohamparikh Jun 5, 2025
b94b1ee
install libjpeg-dev in gh actions
sohamparikh Jun 5, 2025
9e4f14f
fix sampling test
sohamparikh Jun 5, 2025
d1c804f
fix
sohamparikh Jun 6, 2025
75d64a6
fix data cache reloading
sohamparikh Jun 9, 2025
cba6986
fix tokenization
sohamparikh Jun 9, 2025
275fefa
pixtral SFT (#296)
shruthan Jun 11, 2025
605cc7f
review comments
sohamparikh Jun 11, 2025
06aa740
simplified tokenization with spans
sohamparikh Jun 12, 2025
30e3d34
Update fast_llm/data/preparator/gpt_memmap/prepare.py
sohamparikh Jun 12, 2025
c1aa709
rename
sohamparikh Jun 12, 2025
0ada42b
Merge branch 'soham/pixtral-support' of github.com:ServiceNow/Fast-LL…
sohamparikh Jun 12, 2025
4e7afd8
merge main
sohamparikh Jun 12, 2025
8e106f7
fix conversion
sohamparikh Jun 12, 2025
080dcb5
fix sequence lengths, parallel conv
sohamparikh Jun 16, 2025
f186868
minor
sohamparikh Jun 16, 2025
6b9ea2e
fix image at beginning
sohamparikh Jun 16, 2025
ad18ea1
pixtral fix conversion (#315)
RaymondLi0 Jun 20, 2025

Files changed

1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
@@ -27,6 +27,7 @@ jobs:

- name: Install dependencies
run: |
sudo apt install libjpeg-dev
pip install "torch>=2.2.2"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
2 changes: 2 additions & 0 deletions .github/workflows/docs.yaml
@@ -29,6 +29,7 @@ jobs:
restore-keys: |
mkdocs-material-
- run: |
sudo apt install libjpeg-dev
pip install "torch>=2.2.2"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
@@ -56,6 +57,7 @@ jobs:
restore-keys: |
mkdocs-material-
- run: |
sudo apt install libjpeg-dev
pip install "torch>=2.2.2"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \
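Both workflow files gain the same libjpeg-dev install ahead of the Python dependencies, matching the "jpeg dependency" and "install libjpeg-dev in gh actions" commits. The PR does not say which library consumes it; assuming it is Pillow built against libjpeg (an assumption, not confirmed by the diff), a sanity check like the following could confirm JPEG decoding works after install:

```python
# Hypothetical post-install check that JPEG decoding is available, assuming the
# image pipeline relies on Pillow built against libjpeg (not stated in the PR).
from PIL import features

if not features.check("jpg"):
    raise RuntimeError("Pillow lacks libjpeg support; install libjpeg-dev and reinstall Pillow.")
```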
18 changes: 18 additions & 0 deletions fast_llm/data/data/gpt/data.py
@@ -32,6 +32,8 @@ class GPTBatch:
token_ids: torch.Tensor
loss_masking_spans: list[torch.Tensor] | None = None
sequence_lengths: list[torch.Tensor] | None = None
images: list[torch.Tensor] | None = None
image_positions: list[torch.Tensor] | None = None
chosen_spans: list[torch.Tensor] | None = None
rejected_spans: list[torch.Tensor] | None = None

@@ -49,12 +51,28 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling
stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
has_images = False
batch_images = []
for sample in batch:
if sample.images is not None:
batch_images.append([torch.from_numpy(image) for image in sample.images])
has_images = True
else:
batch_images.append([])
batch_image_positions = []
for sample in batch:
if sample.image_positions is not None:
batch_image_positions.append(torch.from_numpy(sample.image_positions))
else:
batch_image_positions.append([])
return GPTBatch(
token_ids=torch.from_numpy(stacked_ids),
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
rejected_spans=stacked_rejected_spans,
images=batch_images if has_images else None,
image_positions=batch_image_positions if has_images else None,
)


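For reference, the image-collation behavior added to gpt_data_collate_fn can be read as the following standalone sketch. It is not the PR's code verbatim; samples here are any objects exposing the .images and .image_positions fields used in the diff.

```python
# Condensed sketch of the image-collation logic added to gpt_data_collate_fn.
# `.images` is a list of numpy arrays or None; `.image_positions` is a numpy array or None.
import torch

def collate_images(samples) -> tuple[list | None, list | None]:
    has_images = False
    batch_images, batch_image_positions = [], []
    for sample in samples:
        if sample.images is not None:
            # Images stay as ragged per-sample lists: counts and resolutions differ per sample.
            batch_images.append([torch.from_numpy(image) for image in sample.images])
            has_images = True
        else:
            batch_images.append([])
        if sample.image_positions is not None:
            batch_image_positions.append(torch.from_numpy(sample.image_positions))
        else:
            batch_image_positions.append([])
    # Text-only batches collapse both fields to None, matching GPTBatch's defaults.
    return (batch_images, batch_image_positions) if has_images else (None, None)
```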
13 changes: 12 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -75,6 +75,10 @@ class GPTSamplingParameters(SamplingParameters):
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False
cross_document_attention: bool = True
patch_size: int | None = None
max_image_size: int | None = None
image_break_token: int | None = None
image_end_token: int | None = None
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
extra_tokens: int = 1
@@ -142,11 +146,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
desc="Expected number of tokens in the dataset.",
hint=FieldHint.optional,
)
num_pixels: int | None = Field(
default=None,
desc="Expected number of pixels in the dataset.",
hint=FieldHint.optional,
)

def build(self) -> "GPTMemmapDataset":
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
return GPTMemmapDataset(
str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels
)


@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"})
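The new sampling parameters (patch_size, max_image_size, image_break_token, image_end_token) imply that each image consumes a token budget during sampling. One plausible accounting, assuming a Pixtral-style layout with a break token after every patch row and an end token closing the image, is sketched below; the PR's actual formula is not visible in this hunk.

```python
# Hypothetical per-image token budget under a Pixtral-style layout (assumption, not from the diff):
# one token per patch, one image-break token after each row, and an image-end token
# taking the place of the final row's break.
def image_token_count(height: int, width: int, patch_size: int) -> int:
    rows = -(-height // patch_size)   # ceil division, in case sizes are not exact multiples
    cols = -(-width // patch_size)
    return rows * cols + rows         # patches plus one break/end token per row

# Example: a 1024x1024 image with 16x16 patches reserves 64*64 + 64 = 4160 tokens.
print(image_token_count(1024, 1024, 16))  # 4160
```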
15 changes: 14 additions & 1 deletion fast_llm/data/dataset/gpt/indexed.py
@@ -34,6 +34,14 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset":
else GPTSampledIndexedDataset(self, sampling)
)

@property
@abc.abstractmethod
def has_images(self) -> bool:
"""
Whether the dataset contains images.
This is used to determine whether to use image-related fields in the sampled data.
"""


class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset):
"""
@@ -44,11 +52,16 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe

def get_document_sizes(self) -> np.ndarray:
# TODO: This can be really big.
return self._dataset.get_document_sizes()[self._begin : self._end]
doc_sizes, im_sizes = self._dataset.get_document_sizes()
return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end] if im_sizes else []

def get_document_size(self, index: int) -> int:
return self._dataset.get_document_size(self._begin + index)

@property
def has_images(self) -> bool:
return self._dataset.has_images


class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset
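
With get_document_sizes now returning a (token sizes, image sizes) pair and has_images required on indexed datasets, a minimal concrete implementation of the contract could look like the sketch below (hypothetical class and field names, not from the PR).

```python
# Illustrative in-memory dataset satisfying the updated GPTIndexedDataset contract;
# only the pieces touched by this diff are sketched.
import numpy as np

class ToyIndexedDataset:
    def __init__(self, token_sizes: np.ndarray, image_sizes: list[list[tuple[int, int]]] | None = None):
        self._token_sizes = token_sizes            # tokens per document
        self._image_sizes = image_sizes or []      # per-document list of (height, width) pairs

    @property
    def has_images(self) -> bool:
        # Mirrors the new abstract property: lets samplers skip image fields entirely.
        return len(self._image_sizes) > 0

    def get_document_sizes(self) -> tuple[np.ndarray, list]:
        # Pair of (token sizes, image sizes), the second empty for text-only data,
        # matching how GPTDatasetSlice slices and forwards both.
        return self._token_sizes, self._image_sizes

    def get_document_size(self, index: int) -> int:
        return int(self._token_sizes[index])
```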