Implement the ModernBert model #459
Conversation
FYI, there is now https://huggingface.co/nomic-ai/modernbert-embed-base.
FYI, running the nomic/modernbert-base model yields an error as the safetensors are not under model.embeddings.* but embeddings.*
thanks! I've just worked on supporting it. I also appreciate your offer of GPU support! Currently I have a lot on my plate, so I'll reach out to you later :) anyway, thanks again for your support
Is this also supported, as it's the same architecture? https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE
It looks like it uses custom tokenizing logic with multiple tokenizers, picking one on the fly depending on the input text. The architecture itself is supported, but I guess it would be hard to use with TEI.
I have found another fine-tune from them which is specifically for German (DE), https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE-DE/blob/main/config.json, but I am having an issue as their config sets pad_token_id to null. I tried to follow through your implementation, but this is where I got stuck: the model expects a pad_token_id.
it seems like it uses
You should fill in missing configs with proper values in
Thank you. I was able to run nomic-ai/modernbert-base following your instructions. The other fine-tuned one I mentioned already had the changes you suggested, but it still seems to struggle with longer text (more than 128 tokens). I wrote to them directly.
great to hear! If you encounter an issue, please let me know. --- updated: I just fixed the bug 63c4224, could you please test with the latest commit?
Thanks a lot! It fixed the bug. I can confirm I no longer see that issue and can run long-text embedding.
Is this on the path to getting merged?
@alvarobartt @regisss Any chance we can merge this soon?
We've noticed with @Narsil a difference when using AlibabaNLP/gte-modernbert-base with sentence-transformers: the embeddings don't match.
@@ -4,7 +4,7 @@ use candle_nn::VarBuilder;
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
-    bias: Tensor,
+    bias: Option<Tensor>,
Perhaps recreate a new LayerNorm class that has no bias rather than changing this?
I second this. At first, I thought it'd be better to create a new LayerNorm struct once the maintainers confirmed their preference and intended usage. If it's okay, I'll add a new LayerNorm named LayerNormNoBias.
That sounds good!
added 41852e4
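For reference, a rough sketch of what a bias-free LayerNorm on top of candle might look like; this is an illustration under assumptions (struct and method names, epsilon handling), not necessarily the exact code added in 41852e4:

```rust
use candle::{DType, Result, Tensor, D};
use candle_nn::VarBuilder;

// Illustrative sketch of a LayerNorm without a learnable bias, as discussed above.
#[derive(Debug)]
pub struct LayerNormNoBias {
    weight: Tensor,
    epsilon: f32,
}

impl LayerNormNoBias {
    pub fn load(vb: VarBuilder, hidden_size: usize, epsilon: f32) -> Result<Self> {
        Ok(Self {
            weight: vb.get(hidden_size, "weight")?,
            epsilon,
        })
    }

    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Normalize over the last (hidden) dimension in f32 for numerical stability.
        let dtype = hidden_states.dtype();
        let hidden_states = hidden_states.to_dtype(DType::F32)?;
        let mean = hidden_states.mean_keepdim(D::Minus1)?;
        let centered = hidden_states.broadcast_sub(&mean)?;
        let variance = centered.sqr()?.mean_keepdim(D::Minus1)?;
        let normalized =
            centered.broadcast_div(&variance.affine(1.0, self.epsilon as f64)?.sqrt()?)?;
        // Scale by the learned weight; there is no bias term.
        normalized.to_dtype(dtype)?.broadcast_mul(&self.weight)
    }
}
```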
let mut new_qkv_shape = qkv.dims().to_vec();
new_qkv_shape.pop();
new_qkv_shape.push(self.num_attention_heads * 3);
new_qkv_shape.push(self.attention_head_size);
let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
Is it possible to transpose in load?
In my opinion, the variable qkv is determined on the fly, so it cannot be transposed in load(). Apart from this, maybe we could refactor this part more cleanly like below. How about this?
let qkv = qkv
.reshape((
b,
seq_len,
3,
self.num_attention_heads,
self.attention_head_size,
))?
.permute((2, 0, 3, 1, 4))?;
let q = qkv.get(0)?;
let k = qkv.get(1)?;
let v = qkv.get(2)?;
let query_layer = &qkv[0].contiguous()?;
let key_layer = &qkv[1].contiguous()?;
Contiguous copies the underlying tensor; could this also be handled in load?
let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();

for use_local_attention in [true, false] {
    let rope_theta = if use_local_attention {
        config.local_rope_theta
    } else {
        config.global_rope_theta
    };

    let max_position_embeddings = if use_local_attention {
        config.max_position_embeddings
    } else {
        config.local_attention
    };

    let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

    let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;

    rotary_cache.insert(use_local_attention, (cos, sin));
HashMap<bool, (Tensor, Tensor)> -> [(Tensor, Tensor); 2]?
I also agree a tuple is enough to store that. I'll refactor this too. Thanks!
refactored dc3aa0a
While a tuple can be a great option, on second thought, defining them with intuitive names could lead to more readable code, so I created a new struct RotaryEmbedding and defined a global rotary embed and a local rotary embed each!
Perhaps we can refactor the other rotary usages with this struct as well.
-- updated 3/31
Due to some problems, I rolled back the RotaryEmbedding struct and defined the global/local rotary caches with a tuple each. You can check here: 43f2322!
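For illustration, a tuple-based version might look roughly like the sketch below, reusing the get_inv_freqs / get_cos_sin helpers and the parameter pairing from the loop shown above; the variable names here are assumptions, not necessarily what 43f2322 uses:

```rust
// Sketch only: the same precomputation as the HashMap loop above, but as two named
// tuples, one per attention type (local vs. global), mirroring that loop's parameters.
let local_inv_freqs = get_inv_freqs(rotary_dim, config.local_rope_theta as f32, vb.device(), None)?;
let local_rotary_cache: (Tensor, Tensor) =
    get_cos_sin(config.max_position_embeddings, &local_inv_freqs, vb.dtype(), true)?;

let global_inv_freqs = get_inv_freqs(rotary_dim, config.global_rope_theta as f32, vb.device(), None)?;
let global_rotary_cache: (Tensor, Tensor) =
    get_cos_sin(config.local_attention, &global_inv_freqs, vb.dtype(), true)?;
```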
    &self.device,
)?;

let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
Same here
refactored dc3aa0a
#[test]
#[serial_test::serial]
fn test_mini_pooled_raw() -> Result<()> {
    let model_root = download_artifacts("sentence-transformers/all-mpnet-base-v2", None)?;
shouldn't this be modernbert?
thanks for catching this. gonna fix this too
fixed 57a04fc
Thanks for your review! I'll work on your review comments and ping you afterward. Thank you for taking the time to review it!
Thanks for the PR @kozistr, just added some minor nits (I'll try to explore why it's not working on MPS) 🤗
backends/candle/src/lib.rs
(Config::ModernBert(config), _) => {
    tracing::info!("Starting ModernBert model on {:?}", device);
    Ok(Box::new(
        ModernBertModel::load(vb, &config, model_type).s()?,
    ))
}
It seems that's not supported on Metal due to Metal strided to_dtype F16 U8 not implemented; could you include a check to emit a BackendError if the device is Device::Metal until solved? Thanks in advance!
Ok, it seems that some kernels are still not in the current candle version cc @Narsil, see huggingface/candle@6eea45a (in this case the cast_f16_u8_strided one is missing).
Thanks for catching this error. Okay, I'll exclude MPS device support until solved.
excluded d2233e5
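For context, such a guard might look roughly like the following fragment of the match in backends/candle/src/lib.rs shown above; this is a sketch under assumptions (the BackendError variant and the error message are illustrative), and d2233e5 may well do it differently:

```rust
// Sketch: reject ModernBert on Metal/MPS until the missing candle Metal kernels land.
// The BackendError::Start variant and wording below are assumptions for illustration.
(Config::ModernBert(_), Device::Metal(_)) => Err(BackendError::Start(
    "ModernBert is not currently supported on MPS/Metal devices".to_string(),
)),
(Config::ModernBert(config), _) => {
    tracing::info!("Starting ModernBert model on {:?}", device);
    Ok(Box::new(
        ModernBertModel::load(vb, &config, model_type).s()?,
    ))
}
```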
@alvarobartt could you please test on the Metal device again with the latest commit? I refactored the get_local_attention_mask function, which may resolve the above issue F16 U8 not implemented.
Hey @kozistr, so I disabled the Device::Metal error and managed to run it with dtype=float32, but it failed with the default precision dtype=float16 due to dtype mismatch in add, lhs: F16, rhs: F32.
If enabling MPS support with the current candle version is too much of a hassle, we can maybe enforce the dtype to be float32 only and raise an error otherwise. Anyway, I expect most of the usage to be CPU or GPU, so let's just document that and move forward 🤗
I just found that fp16 on any device got an error dtype mismatch in add, lhs: F16, rhs: F32, and it's because of attention_mask, which I mistakenly filled with the f32 minimum even when the dtype is fp16.
From now on, fp16 on CPU works (== runs w/o error)!
$ text-embeddings-router --model-id ./gte-modernbert-base --pooling cls --port 8888 --dtype float16 --auto-truncate
[[0.013360826,-0.056719333,-0.016282361...,0.016844664,0.05989757,0.010225371]]
fixed ceccbca
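For illustration, the gist of such a fix is to build the additive attention bias from the active dtype's minimum instead of a hard-coded f32 minimum. A rough sketch follows; the helper name, signature, and mask layout are assumptions, not necessarily what ceccbca does:

```rust
use candle::{DType, Device, Result, Tensor};

// Sketch: convert a 0/1 padding mask (integer dtype) into an additive attention bias
// whose "minus infinity" value matches the active dtype, so fp16 runs don't mix
// F16 and F32 tensors in the later add.
fn attention_bias(mask: &Tensor, dtype: DType, device: &Device) -> Result<Tensor> {
    // Use the smallest finite value representable in the target dtype.
    let min_value = match dtype {
        DType::F16 => -65504.0, // f16 minimum finite value
        _ => f64::from(f32::MIN),
    };
    let on_true = Tensor::zeros(mask.dims(), dtype, device)?;
    let on_false = Tensor::full(min_value, mask.dims(), device)?.to_dtype(dtype)?;
    // Keep attended positions (mask == 1) at 0 and push masked positions to the minimum.
    mask.where_cond(&on_true, &on_false)
}
```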
Apart from this, I've tested with the AlibabaNLP/gte-modernbert-base model (fp16 weights) but failed to get results identical to sentence-transformers, while nomic-ai/modernbert-embed-base and answerdotai/ModernBert-base/large work.
So, how about disabling both MPS and fp16 support for now (if device == MPS or dtype == fp16)? Or, if it works on the MPS device based on ceccbca, we can just drop fp16 support for now.
checked that nomic-ai/modernbert-embed-base, answerdotai/ModernBERT-base, and Alibaba-NLP/gte-modernbert-base are working
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
I have differences with all models under test; how do you check the model outputs? We've migrated to candle 0.8 on main, rebasing should be OK, but I can help with that.
Hi, sorry for the confusion. I made a mistake while refactoring the rotary embedding part dc3aa0a, so I've fixed the bug, and I finally checked that the output looks correct 43f2322. Please let me know if there's still an issue or anything!
Here's my code.
It'd be great to have help with rebasing! Could you please help with that?
Hey thanks a lot for this contribution.
I still see wide differences between sentence-transformers and this implementation here.
I'm not sure if I'm making any mistakes in my comparison but I expect
# Requires transformers>=4.48.0
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
input_texts = [
    "what is the capital of China?",
    "how to implement quick sort in python?",
    "Beijing",
    "sorting algorithms",
]
model_path = "Alibaba-NLP/gte-modernbert-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)
# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=8192, padding=True, truncation=True, return_tensors='pt')
outputs = model(**batch_dict)
embeddings = outputs.last_hidden_state[:, 0]
# (Optionally) normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
to provide the correct embeddings (there is no normalization normally; I'm providing the normalized output just in case).
Also this implementation is lacking the Flash version.
However, in the interest of getting things done, I will merge this already, and attempt to fix the implementation in a follow-up in order to get things moving on this front.
Thanks a lot for this PR !
@Narsil I modified & ran one of the above scripts, and I'm able to get matching results. Modified script for gte-modernbert-base:
import json
import numpy as np
import requests
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, ModernBertModel
from torch.nn.functional import normalize
sentences = [
    'What is Deep Learning?',
    'Deep Learning is...',
    'What is Deep Learning?',
]
model_id = 'Alibaba-NLP/gte-modernbert-base'
# Sentence-Transformers
st_model = SentenceTransformer(model_id)
st_model.eval()
# Transformers
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertModel.from_pretrained(model_id)
model.eval()
with torch.inference_mode():
    st_results = st_model.encode(sentences, normalize_embeddings=True)

tokens = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.inference_mode():
    embeddings = model(**tokens)[0]
    tf_results = normalize(embeddings[:, 0, :], p=2.0).numpy()
# $ text-embeddings-router --model-id Alibaba-NLP/gte-modernbert-base --pooling cls --port 8080 --dtype float32 --auto-truncate
tei_results = np.asarray(
    requests.post(
        'http://127.0.0.1:8080/embed',
        data=json.dumps({'inputs': sentences}),
        headers={'Content-type': 'application/json'},
    ).json()
)
print(st_results)
# [[ 0.01344569 -0.05683083 -0.01630153 ... 0.01694537 0.05998897
# 0.01016068]
# [ 0.01036755 -0.06500905 -0.03710588 ... -0.00754553 0.05179105
# 0.00935954]
# [ 0.01344569 -0.05683083 -0.01630153 ... 0.01694537 0.05998897
# 0.01016068]]
print(tf_results)
# [[ 0.01344565 -0.05683074 -0.01630154 ... 0.01694529 0.05998905
# 0.01016066]
# [ 0.01036758 -0.06500904 -0.03710586 ... -0.0075455 0.05179105
# 0.00935955]
# [ 0.0134457 -0.0568308 -0.01630155 ... 0.01694538 0.05998901
# 0.01016067]]
print(tei_results)
# [[ 0.01344578 -0.05683083 -0.01630141 ... 0.01694547 0.05998898
# 0.01016063]
# [ 0.01036761 -0.06500905 -0.03710582 ... -0.00754556 0.05179107
# 0.00935959]
# [ 0.01344578 -0.05683083 -0.01630141 ... 0.01694547 0.05998898
# 0.01016063]]
I also have everything matching after merging. I have no clue what I botched in my testing; I manually rebased, so maybe I screwed something up there. It wasn't working on CUDA because of an issue I'm following up on, but at least the simple implementation does work, thanks a lot!
What does this PR do?
Close #457
- Bump the tokenizer crate from 0.19.1 to 0.21.0 to address a ModernBert tokenizer issue.
- Implement the ModernBert model and MPS support.
- ModernBert uses local attention. However, I'm unfamiliar with candle_flash_attn and don't have any GPU to test FA2 w/ local attn, so the FlashModernBert implementation remains unsupported at this time.
- Tested ModernBert models:
  - nomic-ai/modernbert-embed-base
  - answerdotai/ModernBert-base/large
- Log
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@OlivierDehaene OR @Narsil