Dynamic batch size for openai embedding models #153
Conversation
Pull Request Overview
This PR implements dynamic batching for OpenAI embedding models by grouping input strings based on their token counts, rather than using a fixed batch size.
- Added a new helper function (_create_token_aware_batches) that creates batches based on token limits.
- Updated embed_strings_without_late_chunking to use the new batching logic for OpenAI models.
- Applied formatting improvements throughout the file.
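For context, here is a minimal sketch of what a token-aware batching helper could look like. This is not the PR's actual code: the real `_create_token_aware_batches` takes the embedder identifier as its second argument (see the diff below), whereas this sketch assumes tiktoken's `cl100k_base` encoding and an illustrative 8192-token limit.

```python
# Hypothetical sketch of token-aware batching (not the exact PR implementation).
import tiktoken


def create_token_aware_batches(
    strings: list[str], max_tokens_per_batch: int = 8192
) -> list[list[str]]:
    """Group strings into batches whose total token count stays under a limit."""
    encoding = tiktoken.get_encoding("cl100k_base")  # Assumed encoding.
    batches: list[list[str]] = []
    current_batch: list[str] = []
    current_tokens = 0
    for string in strings:
        tokens = len(encoding.encode(string))
        # Start a new batch if adding this string would exceed the limit.
        if current_tokens + tokens > max_tokens_per_batch and current_batch:
            batches.append(current_batch)
            current_batch = [string]
            current_tokens = tokens
        else:
            current_batch.append(string)
            current_tokens += tokens
    if current_batch:
        batches.append(current_batch)
    return batches
```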
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
src/raglite/_embed.py | Introduced dynamic, token-aware batching and improved multi-line formatting in various functions. |
pyproject.toml | Added the tiktoken dependency required for token-based batching. |
# If adding this string exceeds limit, start new batch
if current_tokens + tokens > max_tokens and current_batch:
    batches.append(current_batch)
    current_batch = [string]
    current_tokens = tokens
else:
    current_batch.append(string)
    current_tokens += tokens
Consider adding a check to handle cases where an individual string's token count exceeds the maximum allowed tokens, to ensure the batching logic gracefully handles such edge cases.
Suggested change:

# Handle strings that exceed the max token limit
if tokens > max_tokens:
    # Split the string into smaller chunks
    start = 0
    while start < len(string):
        chunk = string[start:start + max_tokens]
        chunk_tokens = len(encoding.encode(chunk))
        if current_tokens + chunk_tokens > max_tokens and current_batch:
            batches.append(current_batch)
            current_batch = []
            current_tokens = 0
        current_batch.append(chunk)
        current_tokens += chunk_tokens
        start += len(chunk)
else:
    # If adding this string exceeds limit, start new batch
    if current_tokens + tokens > max_tokens and current_batch:
        batches.append(current_batch)
        current_batch = [string]
        current_tokens = tokens
    else:
        current_batch.append(string)
        current_tokens += tokens
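As a side note on the suggestion above: slicing with `string[start:start + max_tokens]` splits by characters, which only approximates the token limit (multi-byte characters can encode to more tokens than characters). A token-level split would guarantee the limit. A hypothetical sketch, assuming tiktoken (the helper name `split_by_tokens` is not from the PR):

```python
# Hypothetical: split an oversized string on token boundaries so each chunk fits.
import tiktoken


def split_by_tokens(string: str, max_tokens: int) -> list[str]:
    encoding = tiktoken.get_encoding("cl100k_base")  # Assumed encoding.
    token_ids = encoding.encode(string)
    # Decode fixed-size slices of the token list back into text chunks.
    return [
        encoding.decode(token_ids[start : start + max_tokens])
        for start in range(0, len(token_ids), max_tokens)
    ]
```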
@lsorber WDYT?
IMO, it's impossible for a group of sentences (chunklets) to exceed 8192 tokens.
The modifications cover these functions; the remainder is related to linting.
    batches = _create_token_aware_batches(strings, config.embedder)
else:
    # Original fixed batching for non-OpenAI models
    batch_size = 96
The batch size should never be more than 96, because when using an Azure AI Foundry embedding model (such as Cohere embed4), the maximum batch size allowed by Azure AI Foundry is 96.
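For illustration, a per-request item cap like this could be combined with the token-aware batching; a hypothetical sketch (the helper name and both limits are illustrative, not the PR's code):

```python
# Illustrative only: cap batches by total tokens and by item count,
# e.g. to respect a provider-side maximum of 96 inputs per request.
def create_capped_batches(
    strings: list[str],
    token_counts: list[int],
    max_tokens: int = 8192,
    max_items: int = 96,
) -> list[list[str]]:
    batches: list[list[str]] = []
    current: list[str] = []
    current_tokens = 0
    for string, tokens in zip(strings, token_counts):
        # Close the current batch if adding this string would break either limit.
        if current and (current_tokens + tokens > max_tokens or len(current) >= max_items):
            batches.append(current)
            current, current_tokens = [], 0
        current.append(string)
        current_tokens += tokens
    if current:
        batches.append(current)
    return batches
```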
Closing: the discussion with @ThomasDelsart led to the conclusion that this PR is not needed.
Answers issue #151. Please read the context from that issue.