Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the ModernBert model #459

Merged
merged 25 commits into from
Apr 4, 2025
Merged

Conversation

kozistr
Copy link
Contributor

@kozistr kozistr commented Dec 25, 2024

What does this PR do?

Close #457

  • upgrade the tokenizer crate from 0.19.1 to 0.21.0 to address a ModernBert tokenizer issue.
  • implement ModernBert model
    • it may work on CPU, CUDA (w/o FA), and MPS.
    • ModernBert uses local attention. however, I'm unfamiliar with candle_flash_attn and don't have any GPU to test FA2 w/ local attn, so the FlashModernBert implementation remains unsupported at this time.
  • implement a classification head for ModernBert
  • tested
    • nomic-ai/modernbert-embed-base
    • answerdotai/ModernBert-base/large

Log

$ ./target/release/text-embeddings-router --model-id ./ModernBERT-base/ --port 8888 --pooling cls --dtype float32
2024-12-25T07:34:46.753673Z  INFO text_embeddings_router: router/src/main.rs:175: Args { model_id: "./Mod*******-*ase/", revision: None, tokenization_workers: None, dtype: Some(Float32), pooling: Some(Cls), max_concurrent_requests: 512, max_batch_tokens: 16384, max_batch_requests: None, max_client_batch_size: 32, auto_truncate: false, default_prompt_name: None, default_prompt: None, hf_api_token: None, hostname: "0.0.0.0", port: 8888, uds_path: "/tmp/text-embeddings-inference-server", huggingface_hub_cache: None, payload_limit: 2000000, api_key: None, json_output: false, otlp_endpoint: None, otlp_service_name: "text-embeddings-inference.server", cors_allow_origin: None }
2024-12-25T07:34:46.817444Z  WARN text_embeddings_router: router/src/lib.rs:184: Could not find a Sentence Transformers config
2024-12-25T07:34:46.817472Z  INFO text_embeddings_router: router/src/lib.rs:188: Maximum number of tokens per request: 8192
2024-12-25T07:34:46.817622Z  INFO text_embeddings_core::tokenization: core/src/tokenization.rs:28: Starting 8 tokenization workers
2024-12-25T07:34:46.883933Z  INFO text_embeddings_router: router/src/lib.rs:230: Starting model backend
2024-12-25T07:34:46.884247Z  INFO text_embeddings_backend_candle: backends/candle/src/lib.rs:239: Starting ModernBert model on Cpu
2024-12-25T07:34:47.138974Z  WARN text_embeddings_router: router/src/lib.rs:258: Backend does not support a batch size > 4
2024-12-25T07:34:47.139002Z  WARN text_embeddings_router: router/src/lib.rs:259: forcing `max_batch_requests=4`
2024-12-25T07:34:47.139930Z  INFO text_embeddings_router::http::server: router/src/http/server.rs:1812: Starting HTTP server: 0.0.0.0:8888
2024-12-25T07:34:47.139955Z  INFO text_embeddings_router::http::server: router/src/http/server.rs:1813: Ready
2024-12-25T07:34:52.701893Z  INFO embed{total_time="115.486302ms" tokenization_time="322.4µs" queue_time="363.6µs" inference_time="114.688702ms"}: text_embeddings_router::http::server: router/src/http/server.rs:714: Success

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil

@michaelfeil
Copy link
Contributor

FYI, there is now https://huggingface.co/nomic-ai/modernbert-embed-base.
Let me know if you need GPU access @kozistr

@michaelfeil
Copy link
Contributor

FYI, running the nomic/modernbert-base model yields an error as the safetensors are not under model.embeddings.* but embeddings.*

@kozistr
Copy link
Contributor Author

kozistr commented Jan 6, 2025

FYI, there is now https://huggingface.co/nomic-ai/modernbert-embed-base. Let me know if you need GPU access @kozistr

thanks! I've just worked on supporting nomic-ai/modernbert-embed-base and it seems to be working well too. 3b20211

also appreciate your offer of the GPU support! currently, I'm kinda a lot on my plate so, I'll reach out later to you :) anyway, thanks again for your support

$ ./target/release/text-embeddings-router --model-id ./modernbert-embed-base --port 8888 --pooling mean --dtype float32
2025-01-06T03:09:26.039864Z  INFO text_embeddings_router: router/src/main.rs:175: Args { model_id: "./mod*******-*****-*ase", revision: None, tokenization_workers: None, dtype: Some(Float32), pooling: Some(Mean), max_concurrent_requests: 512, max_batch_tokens: 16384, max_batch_requests: None, max_client_batch_size: 32, auto_truncate: false, default_prompt_name: None, default_prompt: None, hf_api_token: None, hostname: "0.0.0.0", port: 8888, uds_path: "/tmp/text-embeddings-inference-server", huggingface_hub_cache: None, payload_limit: 2000000, api_key: None, json_output: false, otlp_endpoint: None, otlp_service_name: "text-embeddings-inference.server", cors_allow_origin: None }
2025-01-06T03:09:26.126234Z  INFO text_embeddings_router: router/src/lib.rs:188: Maximum number of tokens per request: 8192
2025-01-06T03:09:26.126419Z  INFO text_embeddings_core::tokenization: core/src/tokenization.rs:28: Starting 8 tokenization workers
2025-01-06T03:09:26.196076Z  INFO text_embeddings_router: router/src/lib.rs:230: Starting model backend
2025-01-06T03:09:26.196763Z  INFO text_embeddings_backend_candle: backends/candle/src/lib.rs:239: Starting ModernBert model on Cpu
2025-01-06T03:09:26.459153Z  WARN text_embeddings_router: router/src/lib.rs:258: Backend does not support a batch size > 4
2025-01-06T03:09:26.459182Z  WARN text_embeddings_router: router/src/lib.rs:259: forcing `max_batch_requests=4`
2025-01-06T03:09:26.460282Z  INFO text_embeddings_router::http::server: router/src/http/server.rs:1812: Starting HTTP server: 0.0.0.0:8888
2025-01-06T03:09:26.460306Z  INFO text_embeddings_router::http::server: router/src/http/server.rs:1813: Ready
2025-01-06T03:09:31.426262Z  INFO embed{total_time="121.542397ms" tokenization_time="356.3µs" queue_time="418.8µs" inference_time="120.695897ms"}: text_embeddings_router::http::server: router/src/http/server.rs:714: Success

@touhi99
Copy link

touhi99 commented Jan 9, 2025

is it also supported in the same architecture https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE ?

@kozistr
Copy link
Contributor Author

kozistr commented Jan 11, 2025

is it also supported in the same architecture https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE ?

It looks like it uses custom tokenizing logic that uses multiple tokenizers and determines one tokenizer on the fly, depending on the input text. the architecture in and of itself is supported, but it would be hard to use with TEI I guess.

@touhi99
Copy link

touhi99 commented Jan 14, 2025

is it also supported in the same architecture https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE ?

It looks like it uses custom tokenizing logic that uses multiple tokenizers and determines one tokenizer on the fly, depending on the input text. the architecture in and of itself is supported, but it would be hard to use with TEI I guess.

I have found another fine-tune from them which is specifically for German (DE), https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE-DE/blob/main/config.json

but i am having the issue as their config says pad_token_id null. I tried to follow through your implementation but this is where i stuck where the model is expecting a pad_token_id

@kozistr
Copy link
Contributor Author

kozistr commented Jan 14, 2025

is it also supported in the same architecture https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE ?

It looks like it uses custom tokenizing logic that uses multiple tokenizers and determines one tokenizer on the fly, depending on the input text. the architecture in and of itself is supported, but it would be hard to use with TEI I guess.

I have found another fine-tune from them which is specifically for German (DE), https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE-DE/blob/main/config.json

but i am having the issue as their config says pad_token_id null. I tried to follow through your implementation but this is where i stuck where the model is expecting a pad_token_id

it seems like it uses </s> as a pad token and the token id is 2. link.

  • bos_token_id: 1
  • eos_token_id: 2
  • pad_token_id: 2
  • cls_token_id: 0 // dummy
  • sep_token_id: 0 // dummy

You should fill in missing configs with proper values in config.json. you can check the full config from here

@touhi99
Copy link

touhi99 commented Jan 15, 2025

is it also supported in the same architecture https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE ?

It looks like it uses custom tokenizing logic that uses multiple tokenizers and determines one tokenizer on the fly, depending on the input text. the architecture in and of itself is supported, but it would be hard to use with TEI I guess.

I have found another fine-tune from them which is specifically for German (DE), https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE-DE/blob/main/config.json
but i am having the issue as their config says pad_token_id null. I tried to follow through your implementation but this is where i stuck where the model is expecting a pad_token_id

it seems like it uses </s> as a pad token and the token id is 2. link.

  • bos_token_id: 1
  • eos_token_id: 2
  • pad_token_id: 2
  • cls_token_id: 0 // dummy
  • sep_token_id: 0 // dummy

You should fill in missing configs with proper values in config.json. you can check the full config from here

Thank you. I was able to run nomicai/modernbert-base following your instruction. The other fine-tuned one i mentioned already had some changes as you suggested. but seems still struggling for longer text (more than 128 tokens). I wrote to them directly.

@kozistr
Copy link
Contributor Author

kozistr commented Jan 19, 2025

is it also supported in the same architecture https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE ?

It looks like it uses custom tokenizing logic that uses multiple tokenizers and determines one tokenizer on the fly, depending on the input text. the architecture in and of itself is supported, but it would be hard to use with TEI I guess.

I have found another fine-tune from them which is specifically for German (DE), https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE-DE/blob/main/config.json
but i am having the issue as their config says pad_token_id null. I tried to follow through your implementation but this is where i stuck where the model is expecting a pad_token_id

it seems like it uses </s> as a pad token and the token id is 2. link.

  • bos_token_id: 1
  • eos_token_id: 2
  • pad_token_id: 2
  • cls_token_id: 0 // dummy
  • sep_token_id: 0 // dummy

You should fill in missing configs with proper values in config.json. you can check the full config from here

Thank you. I was able to run nomicai/modernbert-base following your instruction. The other fine-tuned one i mentioned already had some changes as you suggested. but seems still struggling for longer text (more than 128 tokens). I wrote to them directly.

great to hear!

If you encounter an issue, index-select invalid index 128 with dim size 128, it's a bug in my ModernBert implementation, which is the rotary encoding part. I'm currently working on it, and I'll let you know when it's fixed!

--- updated

I just fixed the bug 63c4224, could you please test with the latest commit?

@touhi99
Copy link

touhi99 commented Jan 21, 2025

is it also supported in the same architecture https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE ?

It looks like it uses custom tokenizing logic that uses multiple tokenizers and determines one tokenizer on the fly, depending on the input text. the architecture in and of itself is supported, but it would be hard to use with TEI I guess.

I have found another fine-tune from them which is specifically for German (DE), https://huggingface.co/Parallia/Fairly-Multilingual-ModernBERT-Embed-BE-DE/blob/main/config.json
but i am having the issue as their config says pad_token_id null. I tried to follow through your implementation but this is where i stuck where the model is expecting a pad_token_id

it seems like it uses </s> as a pad token and the token id is 2. link.

  • bos_token_id: 1
  • eos_token_id: 2
  • pad_token_id: 2
  • cls_token_id: 0 // dummy
  • sep_token_id: 0 // dummy

You should fill in missing configs with proper values in config.json. you can check the full config from here

Thank you. I was able to run nomicai/modernbert-base following your instruction. The other fine-tuned one i mentioned already had some changes as you suggested. but seems still struggling for longer text (more than 128 tokens). I wrote to them directly.

great to hear!

If you encounter an issue, index-select invalid index 128 with dim size 128, it's a bug in my ModernBert implementation, which is the rotary encoding part. I'm currently working on it, and I'll let you know when it's fixed!

--- updated

I just fixed the bug 63c4224, could you please test with the latest commit?

Thanks a lot! it fixed the bug. I can confirm, I no longer see that issue and run long text embedding.

@xfalcox
Copy link

xfalcox commented Feb 11, 2025

Is this on the path of getting merged ?

@vrdn-23
Copy link

vrdn-23 commented Mar 20, 2025

@alvarobartt @regisss Any chance we can merge this soon?

Copy link
Member

@McPatate McPatate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've noticed with @Narsil a difference when using AlibabaNLP/gte-modernbert-base with sentence-transformers, the embeddings don't match

@@ -4,7 +4,7 @@ use candle_nn::VarBuilder;
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
bias: Tensor,
bias: Option<Tensor>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps recreate a new LayerNorm class that has no bias rather than changing this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second this. At first, I believed it'd be great to recreate a new LayerNorm struct after the maintainers confirm it regarding their preferences and usages.

if it's okay, I'll add a new LayerNorm named with LayerNormNoBias.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds good!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added 41852e4

Comment on lines +78 to +82
let mut new_qkv_shape = qkv.dims().to_vec();
new_qkv_shape.pop();
new_qkv_shape.push(self.num_attention_heads * 3);
new_qkv_shape.push(self.attention_head_size);
let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to transpose in load?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, the variable qkv is determined on the fly, so it cannot be transposed in load().

Apart from this, maybe we could refactor this part more cleanly like below. how about this?

        let qkv = qkv
            .reshape((
                b,
                seq_len,
                3,
                self.num_attention_heads,
                self.attention_head_size,
            ))?
            .permute((2, 0, 3, 1, 4))?;

        let q = qkv.get(0)?;
        let k = qkv.get(1)?;
        let v = qkv.get(2)?;

Comment on lines +85 to +86
let query_layer = &qkv[0].contiguous()?;
let key_layer = &qkv[1].contiguous()?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Contiguous copies the underlying tensor, could this also be handled in load?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +289 to +308
let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();

for use_local_attention in [true, false] {
let rope_theta = if use_local_attention {
config.local_rope_theta
} else {
config.global_rope_theta
};

let max_position_embeddings = if use_local_attention {
config.max_position_embeddings
} else {
config.local_attention
};

let inv_freqs = get_inv_freqs(rotary_dim, rope_theta as f32, vb.device(), None)?;

let (cos, sin) = get_cos_sin(max_position_embeddings, &inv_freqs, vb.dtype(), true)?;

rotary_cache.insert(use_local_attention, (cos, sin));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HashMap<bool, (Tensor, Tensor)> -> [(Tensor, Tensor); 2] ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also agree tuple is enough to save that. I'll refactor this too. thanks!

Copy link
Contributor Author

@kozistr kozistr Mar 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactored dc3aa0a

While using a tuple can be a great option, on second thought, defining them with intuitive names could lead to more readable code, so I created a new struct RotarayEmbedding, and defined global rotary embed and local rotary embed each!

perhaps we can refactor another rotary usage with this struct.

-- updated 3/31

Due to some problems, I rolled back the RotaryEmbedding struct and defined global/local rotary cache with a tuple each. you can check here 43f2322!

&self.device,
)?;

let mut rotary_cache: HashMap<bool, (Tensor, Tensor)> = HashMap::new();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactored dc3aa0a

#[test]
#[serial_test::serial]
fn test_mini_pooled_raw() -> Result<()> {
let model_root = download_artifacts("sentence-transformers/all-mpnet-base-v2", None)?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be modernbert?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for catching this. gonna fix this too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed 57a04fc

@kozistr
Copy link
Contributor Author

kozistr commented Mar 25, 2025

We've noticed with @Narsil a difference when using AlibabaNLP/gte-modernbert-base with sentence-transformers, the embeddings don't match

Thanks for your review! I'll look into the AlibabaNLP/gte-modernbert-base model too.

I'll work on your review and ping you afterward. Thank you for taking the time to review it!

Copy link
Member

@alvarobartt alvarobartt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @kozistr, just added some minor nits (I'll try to explore why it's not working on MPS) 🤗

Comment on lines 238 to 243
(Config::ModernBert(config), _) => {
tracing::info!("Starting ModernBert model on {:?}", device);
Ok(Box::new(
ModernBertModel::load(vb, &config, model_type).s()?,
))
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that's not supported on Metal due to Metal strided to_dtype F16 U8 not implemented could you include a check to emit a BackendError if the device is Device::Metal until solved? Thanks in advance!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok it seems that some kernels are not still within the current candle version cc @Narsil, see huggingface/candle@6eea45a (in this case the cast_f16_u8_strided one is missing)

Copy link
Contributor Author

@kozistr kozistr Mar 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this error. okay I'll exclude MPS device support until solved.

excluded d2233e5

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alvarobartt could you please test on the Metal device again with the latest commit? I refactored the get_local_attention_mask function, which may resolve the above issue F16 U8 not implemented.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @kozistr, so I disabled Device::Metal error and managed to run it with dtype=float32, but it failed with the default precision dtype=float16 due to dtype mismatch in add, lhs: F16, rhs: F32

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If enabling MPS support with the current candle version is too much of a hustle, we can maybe enforce the dtype to be float32 only, and raise an error otherwise. Anyway I expect most of the usage to be CPU or GPU, so let's just document that and move forward 🤗

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i.e. something like the following

image

Copy link
Contributor Author

@kozistr kozistr Mar 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just found that fp16 with any devices got an error dtype mismatch in add, lhs: F16, rhs: F32 and it's because of attention_mask, which I mistakenly set f32 min even if the dtype is fp16.

From now on, fp16 with CPU works (== runs w/o error)!

$ text-embeddings-router --model-id ./gte-modernbert-base --pooling cls --port 8888 --dtype float16 --auto-truncate
[[0.013360826,-0.056719333,-0.016282361...,0.016844664,0.05989757,0.010225371]]

fixed ceccbca

Apart from this, I've tested with the AlibabaNLP/gte-modernbert-base model (fp16 weight), but failed to get identical results with sentence-transformers, while nomic-ai/modernbert-embed-base and answerdotai/ModernBert-base, large work.

So, how about disabling both MPS and fp16 support for now? (`if device == MPS or dtype == fp16),

or maybe if it works on MPS device based on ceccbca, we can just drop fp16 support for now.

Copy link
Contributor Author

@kozistr kozistr Mar 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#459 (comment)

checked that nomic-ai/modernbert-embed-base, answerdotai/ModernBERT-base, and Alibaba-NLP/gte-modernbert-base are working

kozistr and others added 10 commits March 28, 2025 14:30
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
@Narsil
Copy link
Collaborator

Narsil commented Mar 28, 2025

I have differences with all models under tests, how do you check the model outputs ?

We've migrated to candle 0.8 on main, rebasing should be ok, but I can help with that.

@kozistr
Copy link
Contributor Author

kozistr commented Mar 31, 2025

I have differences with all models under tests, how do you check the model outputs ?

We've migrated to candle 0.8 on main, rebasing should be ok, but I can help with that.

hi. Sorry for the confusion. I made a mistake while refactoring the rotary embedding part dc3aa0a, so I've fixed the bug.

And I finally checked the output looks correct 43f2322. Please let me know if there's still an issue or anything!

  • nomic-ai/modernbert-embed-base
  • answerdotai/ModernBERT-base
  • Alibaba-NLP/gte-modernbert-base

Here's my code.

# sentence-transformers 4.0.1
# transformers 4.50.3
# tokenizers 0.21.1
# tested on CPU

sentences = [
    'What is Deep Learning?',
    'Deep Learning is...',
    'What is Deep Learning?',
]

model_id = 'answerdotai/ModernBERT-base'

# Sentence-Transformers
st_model = SentenceTransformer(model_id)
st_model._modules['1'].pooling_mode_mean_tokens = False
st_model._modules['1'].pooling_mode_cls_token = True
st_model._modules['1'].include_prompt = False
st_model.eval()

# Transformers
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertModel.from_pretrained(model_id)
model.eval()

with torch.inference_mode():
    st_results = st_model.encode(sentences, normalize_embeddings=True)

tokens = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.inference_mode():
    embeddings = model(**tokens)[0]
    tf_results = normalize(embeddings[:, 0, :], p=2.0).numpy()

# $ text-embeddings-router --model-id ./ModernBERT-base/ --pooling cls --port 8080 --dtype float32 --auto-truncate
tei_results = np.asarray(
    requests.post(
        'http://127.0.0.1:8080/embed',
        data=json.dumps({'inputs': sentences}),
        headers={'Content-type': 'application/json'},
    ).json()
)

# st_results
array([[-0.00473989, -0.01706669, -0.01927578, ..., -0.02717395,
        -0.00850304, -0.01998385],
       [ 0.02099864, -0.02351814, -0.02991876, ..., -0.02211199,
         0.02051133, -0.02203598],
       [-0.00473989, -0.01706669, -0.01927578, ..., -0.02717395,
        -0.00850304, -0.01998385]], dtype=float32)
# tf results
array([[-0.0047399 , -0.01706671, -0.01927581, ..., -0.02717397,
        -0.00850303, -0.01998387],
       [ 0.02099863, -0.0235182 , -0.02991882, ..., -0.02211192,
         0.02051148, -0.02203591],
       [-0.00473991, -0.0170667 , -0.01927581, ..., -0.02717397,
        -0.00850306, -0.0199839 ]], dtype=float32)
# tei results
array([[-0.00472688, -0.01706322, -0.01927264, ..., -0.02720699,
        -0.00847499, -0.02000728],
       [ 0.02097508, -0.02304459, -0.02991414, ..., -0.02252471,
         0.01963887, -0.0224616 ],
       [-0.00472688, -0.01706322, -0.01927264, ..., -0.02720699,
        -0.00847499, -0.02000728]])

We've migrated to candle 0.8 on main, rebasing should be ok, but I can help with that.

It'd be great to help with rebasing! Could you please help with that?

Copy link
Collaborator

@Narsil Narsil left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey thanks a lot for this contribution.

I still see wide differences between sentence-transformers and this implementation here.
I'm not sure if I'm making any mistakes in my comparison but I expect

# Requires transformers>=4.48.0

import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

input_texts = [
    "what is the capital of China?",
    "how to implement quick sort in python?",
    "Beijing",
    "sorting algorithms"
]

model_path = "Alibaba-NLP/gte-modernbert-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path)

# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=8192, padding=True, truncation=True, return_tensors='pt')

outputs = model(**batch_dict)
embeddings = outputs.last_hidden_state[:, 0]
 
# (Optionally) normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)

To provide the correct embeddings (there is no normalization normally I'm providing the normalized output in case).

Also this implementation is lacking the Flash version.

However, in the interest of getting things done, I will merge this already, and attempt to fix the implementation in a follow-up in order to get things moving on this front.

Thanks a lot for this PR !

@Narsil Narsil merged commit 5104236 into huggingface:main Apr 4, 2025
@tomaarsen
Copy link
Member

@Narsil I modified & ran one of the above scripts, and I'm able to get matching results:

Modified script for gte-modernbert-base
import json
import numpy as np
import requests
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, ModernBertModel
from torch.nn.functional import normalize

sentences = [
    'What is Deep Learning?',
    'Deep Learning is...',
    'What is Deep Learning?',
]

model_id = 'Alibaba-NLP/gte-modernbert-base'

# Sentence-Transformers
st_model = SentenceTransformer(model_id)
st_model.eval()

# Transformers
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ModernBertModel.from_pretrained(model_id)
model.eval()

with torch.inference_mode():
    st_results = st_model.encode(sentences, normalize_embeddings=True)

tokens = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
with torch.inference_mode():
    embeddings = model(**tokens)[0]
    tf_results = normalize(embeddings[:, 0, :], p=2.0).numpy()

# $ text-embeddings-router --model-id Alibaba-NLP/gte-modernbert-base --pooling cls --port 8080 --dtype float32 --auto-truncate
tei_results = np.asarray(
    requests.post(
        'http://127.0.0.1:8080/embed',
        data=json.dumps({'inputs': sentences}),
        headers={'Content-type': 'application/json'},
    ).json()
)

print(st_results)
# [[ 0.01344569 -0.05683083 -0.01630153 ...  0.01694537  0.05998897
#    0.01016068]
#  [ 0.01036755 -0.06500905 -0.03710588 ... -0.00754553  0.05179105
#    0.00935954]
#  [ 0.01344569 -0.05683083 -0.01630153 ...  0.01694537  0.05998897
#    0.01016068]]
print(tf_results)
# [[ 0.01344565 -0.05683074 -0.01630154 ...  0.01694529  0.05998905
#    0.01016066]
#  [ 0.01036758 -0.06500904 -0.03710586 ... -0.0075455   0.05179105
#    0.00935955]
#  [ 0.0134457  -0.0568308  -0.01630155 ...  0.01694538  0.05998901
#    0.01016067]]
print(tei_results)
# [[ 0.01344578 -0.05683083 -0.01630141 ...  0.01694547  0.05998898
#    0.01016063]
#  [ 0.01036761 -0.06500905 -0.03710582 ... -0.00754556  0.05179107
#    0.00935959]
#  [ 0.01344578 -0.05683083 -0.01630141 ...  0.01694547  0.05998898
#    0.01016063]]
  • Tom Aarsen

@Narsil
Copy link
Collaborator

Narsil commented Apr 4, 2025

I also have everything matching after merging. I have no clue what I botched in my testing. I manually rebased so maybe I screwed something there.

It wasn't working on Cuda because of abs() in the local mask (there's no such op in candle) and the Flash Attention backend was never called.

I'm following this up, but at least the simple implementation does work, thanks a lot !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

support for answerdotai/ModernBERT-base
9 participants