Replication of results with the original model #5

Closed
@pengzhangzhi

Description

Hi,

I would like to report a significant discrepancy between the outputs of your ESM2 implementation (loaded using the esme library) and the Hugging Face implementation (facebook/esm2_t6_8M_UR50D). While comparing the logits and representations from both models, I observed differences large enough that I believe they warrant further investigation.

•	Logits:
	•	Max Absolute Difference: 1.2461
	•	Mean Absolute Difference: 0.1458
•	Representations:
	•	Max Absolute Difference: 0.1914
	•	Mean Absolute Difference: 0.0226

Is this expected, or is something wrong?

Below is the code to reproduce the comparison:

import torch
from esme import ESM2
from esme.alphabet import tokenize
from transformers import EsmTokenizer, EsmForMaskedLM

# ------------------------- Documentation -------------------------
# This script compares the outputs of two implementations of the ESM2 model:
# 1. A custom implementation (loaded using ESM2 and esme.alphabet)
# 2. A pre-trained Hugging Face implementation (using transformers library)
# The comparison focuses on:
# - Logits (model output before applying softmax)
# - Representations (final hidden states of the model)
# It computes and prints the absolute maximum and mean differences between the outputs
# for debugging purposes.

# ---------------------- Load Models and Tokenizer ----------------------
# Load the tokenizer and Hugging Face pre-trained ESM2 model
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Set the device to GPU (if available) or CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
esm = esm.to(device)
esm.eval()  # Set the model to evaluation mode

# Load the custom ESM2 model
model = ESM2.from_pretrained("8M.safetensors", device=0)

# ------------------------- Prepare Input Tokens -------------------------
# Define input protein sequences for comparison
sequences = ['MEES', 'M']  # Example sequences
tokens = tokenize(sequences).to(0)  # Tokenize using custom tokenizer

# ---------------------- Custom Model Forward Pass ----------------------
# Generate logits and representations using the custom model.
# Inference only, so disable gradient tracking.
with torch.no_grad():
    logits_custom = model(tokens).float()  # Model logits
    repr_custom = model.forward_representation(tokens).float()  # Model representations

# ---------------------- Hugging Face Model Forward Pass ----------------------
# Generate logits and representations using the Hugging Face model.
# Note: this reuses the tokens from the custom tokenizer, which assumes both
# tokenizers map residues and special tokens to the same ids.
attention_mask = tokens != tokenizer.pad_token_id  # Attention mask for valid tokens
with torch.no_grad():
    esm_output = esm(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=True)
logits_esm = esm_output.logits.float()  # Model logits
repr_esm = esm_output.hidden_states[-1].float()  # Last hidden state as representation

# ------------------------- Mask Padding Tokens -------------------------
# Mask out padding tokens for comparison
logits_custom = logits_custom[attention_mask]
logits_esm = logits_esm[attention_mask]
repr_custom = repr_custom[attention_mask]
repr_esm = repr_esm[attention_mask]

# ------------------------- Compute Differences -------------------------
# Compute absolute differences between logits
logit_diff_max = (logits_custom - logits_esm).abs().max()
logit_diff_mean = (logits_custom - logits_esm).abs().mean()

# Compute absolute differences between representations
repr_diff_max = (repr_custom - repr_esm).abs().max()
repr_diff_mean = (repr_custom - repr_esm).abs().mean()

# ------------------------- Print Debug Information -------------------------
# Print logit comparison results
print("Logit Comparison:")
print(f"Max Absolute Difference: {logit_diff_max}")
print(f"Mean Absolute Difference: {logit_diff_mean}")
print(f"Logit Difference Shape: {logits_custom.shape}")

# Print representation comparison results
print("\nRepresentation Comparison:")
print(f"Max Absolute Difference: {repr_diff_max}")
print(f"Mean Absolute Difference: {repr_diff_mean}")
print(f"Representation Difference Shape: {repr_custom.shape}")
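
As a rough sanity check on how much of the gap precision loss alone could explain, here is a minimal sketch that simulates bfloat16 rounding in plain Python. It assumes the esme model runs in reduced precision (the `.float()` casts on its outputs suggest this, but I have not confirmed it). The value range of ±10 is an arbitrary stand-in, not measured logits. `to_bfloat16` is a hypothetical helper written for this illustration and is not part of esme or transformers.

```python
import random
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    Hypothetical helper for illustration; not part of esme or transformers.
    """
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of float32. Add a rounding bias so the
    # truncation below rounds to nearest, with ties going to the even value.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


# Assumption: stand-in values with magnitude up to 10, not real model outputs.
random.seed(0)
values = [random.uniform(-10.0, 10.0) for _ in range(1000)]
rounded = [to_bfloat16(v) for v in values]

# Same two metrics as in the script above, applied to a single rounding step.
diffs = [abs(v - r) for v, r in zip(values, rounded)]
print(f"Max Absolute Difference: {max(diffs)}")
print(f"Mean Absolute Difference: {sum(diffs) / len(diffs)}")
```

For magnitudes below 16, adjacent bfloat16 values are at most 0.0625 apart, so a single rounding step can shift a value by at most about 0.03. Under these assumptions, a max difference above 1 would have to come from error accumulating across layers or from a real mismatch between the two implementations, which is why I am asking.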
