
Ensure dtype consistency in Pooling forward method #2492

Merged
4 commits merged into UKPLab:master on Feb 21, 2024

Conversation

EliasKassapis
Contributor

Adjusted the Pooling module's forward function to initialize new tensors with the same dtype as the input tensors, fixing dtype mismatch errors in mixed-precision settings (e.g., after model.half()). This change prevents errors arising from hard-coded .float() usage, enabling seamless operation across different dtype environments.

@tomaarsen
Collaborator

Hello!

I think this makes sense at a glance! You mention that model.half() currently causes problems; is that during inference? Please let me know, or provide a simple reproduction snippet, so I can verify whether this PR solves the problem.

  • Tom Aarsen

@EliasKassapis
Contributor Author

Hey Tom, thanks for replying so quickly. Before the fix, this part of my code led to a runtime error:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("distiluse-base-multilingual-cased-v2").to(device)  # device defined elsewhere
model.half()  # cast all module parameters to float16
model.encode(input_text, show_progress_bar=False)  # input_text defined elsewhere; raised the error below

Error: RuntimeError: mat1 and mat2 must have the same dtype, but got Float and Half

The error originates in the model.encode(...) call, specifically in the model's Dense layer (whose parameters are float16 after invoking model.half()), which follows the Pooling module. The Pooling module returned a dict in which the sentence_embedding key held a float32 tensor, due to hardcoded .float() tensor operations in its forward method.
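
The underlying PyTorch behaviour can also be reproduced without sentence-transformers; the snippet below is only an illustration of the error (layer sizes are arbitrary), not code from the library:

import torch

dense = torch.nn.Linear(768, 512).half()  # stands in for the Dense layer after model.half()
pooled = torch.randn(1, 768)              # float32, like the sentence_embedding returned by Pooling
dense(pooled)                             # RuntimeError: mat1 and mat2 must have the same dtype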

This commit resolves the mismatch by replacing the hardcoded .float() with .to(token_embeddings.dtype), thereby maintaining dtype consistency within the Pooling module's forward function.
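
For reference, a self-contained sketch of the mean-pooling path with the dtype-consistent mask (assuming the usual masked-mean formulation; not the exact diff from the commit):

import torch

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Before the fix, the expanded mask was created with .float(), silently upcasting
    # the result to float32 even when token_embeddings were float16.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts  # keeps the dtype of token_embeddings

token_embeddings = torch.randn(2, 4, 8, dtype=torch.float16)
attention_mask = torch.ones(2, 4, dtype=torch.long)
print(mean_pool(token_embeddings, attention_mask).dtype)  # torch.float16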

@tomaarsen
Collaborator

Very clear! I can reproduce this, too. I've run make style to satisfy the code quality check, and I've added a very simple test case. Thanks for this work! I'll merge it when the CI is green :)

  • Tom Aarsen
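
The added test isn't reproduced here, but a minimal regression test for this behaviour could look roughly like the following (model name, arguments, and assertion are illustrative assumptions; running float16 end to end may require a GPU or a PyTorch build with CPU float16 support):

import torch
from sentence_transformers import SentenceTransformer

def test_encode_half_precision():
    model = SentenceTransformer("distiluse-base-multilingual-cased-v2")
    model.half()
    # Before the fix this raised: RuntimeError: mat1 and mat2 must have the same dtype
    embedding = model.encode("A simple sentence.", convert_to_tensor=True, show_progress_bar=False)
    assert embedding.dtype == torch.float16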

@tomaarsen merged commit 20056c6 into UKPLab:master on Feb 21, 2024
9 checks passed