pixtral SFT #296
# not sure of assignment, reading flag to indicate whether preference loss-masking spans are present
self._has_preference_spans = struct.unpack("<B", stream.read(1))[0]
it's already read above, why read it here?
There's another flag written after images, this is to read that, but not sure what the assignment should be
Fast-LLM/fast_llm/data/dataset/gpt/memmap.py
Lines 407 to 412 in 2f85615
# Placeholder flag for preference spans
idx_stream.write(struct.pack("<B", 0))
# Flag to indicate whether images are present
idx_stream.write(struct.pack("<B", 1 if total_images > 0 else 0))
# Flag to indicate whether preference loss-masking spans are present
idx_stream.write(struct.pack("<B", 1 if chosen_spans.size > 0 and rejected_spans.size > 0 else 0))
oh yeah that order should be flipped, the chosen_spans byte should be before total_images, i'll fix it in my branch
This would break files with version==3 right?
It seems to me that we should rather fix the order in which those flags are dumped in the idx file below
@RaymondLi0 yes, I'm planning to fix it in #227
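The ordering concern in this thread can be illustrated with a minimal sketch. It is not the Fast-LLM reader or writer: `write_flags`, `read_flags`, and `read_flags_swapped` are hypothetical helpers, and only the one-byte `struct.pack("<B", ...)` encoding is taken from the snippet quoted above. The point is that single-byte flags carry no labels, so a reader that consumes them in a different order than the writer silently swaps their meaning.

```python
import io
import struct


def write_flags(stream, has_images, has_preference_spans):
    # Assumed order for this sketch: image flag first, then preference-span flag.
    stream.write(struct.pack("<B", 1 if has_images else 0))
    stream.write(struct.pack("<B", 1 if has_preference_spans else 0))


def read_flags(stream):
    # Consumes the bytes in exactly the order they were written.
    has_images = struct.unpack("<B", stream.read(1))[0]
    has_preference_spans = struct.unpack("<B", stream.read(1))[0]
    return bool(has_images), bool(has_preference_spans)


def read_flags_swapped(stream):
    # Hypothetical reader that assumes the opposite order.
    has_preference_spans = struct.unpack("<B", stream.read(1))[0]
    has_images = struct.unpack("<B", stream.read(1))[0]
    return bool(has_images), bool(has_preference_spans)


buffer = io.BytesIO()
write_flags(buffer, has_images=True, has_preference_spans=False)

buffer.seek(0)
matching = read_flags(buffer)          # flags come back as written
buffer.seek(0)
swapped = read_flags_swapped(buffer)   # flags come back exchanged
```

Because no error is raised in the swapped case, changing the order on either side affects every index file already written with the old order, which is the version==3 compatibility concern raised above.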
fast_llm/data/dataset/gpt/memmap.py
Outdated
total_pixels_needed = sum(
    length[0] * length[1] * 3 for length in self._image_lengths[idx]
)
-total_pixels_needed = sum(
-    length[0] * length[1] * 3 for length in self._image_lengths[idx]
-)
+total_pixels_needed = self._image_lengths[idx].prod(initial=3, axis=1).sum()
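As a quick check that the suggested one-liner matches the original generator expression, here is a small sketch. The `image_lengths` array is made-up data, assumed to hold one (height, width) row per image as a 2-D NumPy integer array; the factor 3 is the channel count used in the snippet above.

```python
import numpy as np

# Hypothetical (height, width) pairs for three images in one document.
image_lengths = np.array([[4, 6], [2, 3], [10, 10]])

# Original form: explicit Python loop over the rows.
loop_total = sum(length[0] * length[1] * 3 for length in image_lengths)

# Suggested form: per-row product seeded with 3 (the channels), then summed.
vectorized_total = image_lengths.prod(initial=3, axis=1).sum()

# A document without images, represented here as an array with zero rows.
no_images = np.empty((0, 2), dtype=np.int64)
empty_total = no_images.prod(initial=3, axis=1).sum()
```

Both forms give the same total for this input, and the vectorized form returns 0 for an array with zero rows, as the loop would.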
fast_llm/data/dataset/gpt/memmap.py
Outdated
     offset=self._pointers[idx] + self._document_sizes[idx] * np.dtype(self._dtype).itemsize,
 )
 images = []
 start = 0
 for image_length in self._image_lengths[idx]:
-    n_pixels = image_length.prod(initial=3)
+    n_pixels = image_length[0] * image_length[1] * 3
can leave it as using .prod?
fast_llm/data/dataset/gpt/sampled.py
Outdated
@@ -549,7 +549,7 @@ def __getitem__(self, index: int) -> typing.Any:
     use_loss_masking_spans=self._parameters.use_loss_masking_spans,
 )
 start_pos = 0
-if sample.image_positions:
+if sample.image_positions is not None:
use a bool has_images = bool(sample.image_positions) and use it below as well?
I think bool(sample.image_positions) will throw ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
when there is more than one image position, e.g. bool(np.array([2, 3]))
or simply has_images = True if sample.image_positions else False?
if sample.image_positions would throw the same error, so I changed it to has_image_positions = sample.image_positions is not None
LGTM! Thanks!
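For reference, the behaviour discussed in this thread can be reproduced with a short sketch. The helper `has_image_positions` is hypothetical and only mirrors the `is not None` check that was settled on; it assumes image positions are either `None` or a NumPy array.

```python
import numpy as np


def has_image_positions(image_positions):
    # Identity check against None; never evaluates the array's truth value.
    return image_positions is not None


positions = np.array([2, 3])

# Truth-testing a multi-element array is ambiguous, so NumPy raises ValueError.
try:
    bool(positions)
    raised = False
except ValueError:
    raised = True

present = has_image_positions(positions)
absent = has_image_positions(None)
```

One design consequence to keep in mind: with this check an empty array still counts as "present", so callers that need to distinguish an empty array from a populated one would have to test its size as well.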
Description
Some bug fixes for Image+Text SFTs
Type of change
Select all that apply:
Changes
List the key changes introduced in this PR:
Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.