Add ESMFold #19977

Merged (46 commits) on Nov 1, 2022
Changes from 1 commit
Commits
5fdd272
initial commit
Rocketknight1 Oct 20, 2022
2996fce
First draft that gets outputs without crashing!
Rocketknight1 Oct 25, 2022
74229d2
Add all the ported openfold dependencies
Rocketknight1 Oct 26, 2022
4577374
testing
Rocketknight1 Oct 27, 2022
f1641cc
Restructure config files for ESMFold
Rocketknight1 Oct 28, 2022
65569f0
Debugging to find output discrepancies
Rocketknight1 Oct 31, 2022
26f3e25
Mainly style
sgugger Oct 31, 2022
988752a
Make model runnable without extra deps
sgugger Oct 31, 2022
92a5887
Remove utils and merge them to the modeling file
sgugger Oct 31, 2022
dd1f480
Use correct gelu and remove some debug prints
Rocketknight1 Oct 31, 2022
6e31a39
More cleanup
sgugger Oct 31, 2022
0e69779
Update esm docs
Rocketknight1 Oct 31, 2022
010fab6
Update conversion script to support ESMFold properly
Rocketknight1 Oct 31, 2022
94d0663
Port some top-level changes from ESMFold repo
Rocketknight1 Oct 31, 2022
728f21f
Expand EsmFold docstrings
Rocketknight1 Oct 31, 2022
28839b7
Make attention_mask optional (default to all 1s)
Rocketknight1 Oct 31, 2022
48c97f4
Add inference test for ESMFold
Rocketknight1 Oct 31, 2022
2422d11
Use config and not n kwargs
sgugger Oct 31, 2022
fcbf85d
Merge branch 'add_esmfold' of github.com:huggingface/transformers int…
sgugger Oct 31, 2022
e7bf6a5
Add modeling output class
sgugger Oct 31, 2022
69d5169
Remove einops
sgugger Oct 31, 2022
5770297
Remove chunking in ESM FFN
Rocketknight1 Oct 31, 2022
f8a9945
Update tests for ESMFold
Rocketknight1 Oct 31, 2022
6ab675c
Quality
sgugger Oct 31, 2022
88757cb
Merge branch 'add_esmfold' of github.com:huggingface/transformers int…
sgugger Oct 31, 2022
1ead4c0
Repo consistency
sgugger Oct 31, 2022
cff0224
Remove tree dependency from ESMFold
Rocketknight1 Oct 31, 2022
b83a592
Merge remote-tracking branch 'origin/add_esmfold' into add_esmfold
Rocketknight1 Oct 31, 2022
5b0fbae
make fixup
Rocketknight1 Oct 31, 2022
bac51f2
Add an error in case my structure map function breaks later
Rocketknight1 Oct 31, 2022
61d6581
Remove needless code
sgugger Oct 31, 2022
44ed50f
Fix merge conflicts
sgugger Oct 31, 2022
f5e7575
Stop auto-casting the LM to float16 so CPU tests pass
Rocketknight1 Oct 31, 2022
8bbf375
Stop auto-casting the LM to float16 so CPU tests pass
Rocketknight1 Oct 31, 2022
7632a12
Final test updates
Rocketknight1 Oct 31, 2022
d14fddb
Split test file
sgugger Oct 31, 2022
a91465b
Copyright and quality
sgugger Oct 31, 2022
e45e6fc
Unpin PyTorch to see built doc
sgugger Oct 31, 2022
7bebbbf
Fix config file to_dict() method
Rocketknight1 Oct 31, 2022
60d681c
Add some docstrings to the output
Rocketknight1 Oct 31, 2022
e2d9ff7
Skip TF checkpoint tests for ESM until we reupload those
Rocketknight1 Oct 31, 2022
e765f59
make fixup
Rocketknight1 Oct 31, 2022
9c9f9fa
More docstrings
Rocketknight1 Nov 1, 2022
fd455c4
Unpin to get even with main
sgugger Nov 1, 2022
020a302
Merge branch 'add_esmfold' of github.com:huggingface/transformers int…
sgugger Nov 1, 2022
f6f4a2c
Flag example to write
sgugger Nov 1, 2022
Debugging to find output discrepancies
Rocketknight1 authored and sgugger committed Oct 31, 2022
commit 65569f03ad7be280370d9c9b5551c2b20cd57e14
80 changes: 58 additions & 22 deletions src/transformers/models/esm/convert_esm.py
@@ -36,7 +36,7 @@
from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
from transformers.models.esm.tokenization_esm import EsmTokenizer
from transformers.utils import logging
from esm.esmfold.v1.pretrained import esmfold_3B_v1
from esm.esmfold.v1.pretrained import esmfold_v1


logging.set_verbosity_info()
@@ -62,9 +62,16 @@
"esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D,
"esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D,
"esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D,
"esmfold_v1": esmfold_3B_v1,
"esmfold_v1": esmfold_v1,
}

def transfer_and_check_weights(original_module, our_module):
status = our_module.load_state_dict(original_module.state_dict())
if status.missing_keys:
raise ValueError(f"Missing keys: {status.missing_keys}")
if status.unexpected_keys:
raise ValueError(f"Unexpected keys: {status.unexpected_keys}")


def convert_esm_checkpoint_to_pytorch(
model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str
@@ -231,9 +238,14 @@ def convert_esm_checkpoint_to_pytorch(
# end of layer

if is_folding_model:
breakpoint() # Can I just copy state dicts or something from here?
model.esm_s_combine.weight = esm.esm_s_combine.weight

model.esm_s_combine.data = esm.esm_s_combine.data
transfer_and_check_weights(esm.embedding, model.embedding)
transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
transfer_and_check_weights(esm.trunk, model.trunk)
transfer_and_check_weights(esm.distogram_head, model.distogram_head)
transfer_and_check_weights(esm.ptm_head, model.ptm_head)
transfer_and_check_weights(esm.lm_head, model.lm_head)
transfer_and_check_weights(esm.lddt_head, model.lddt_head)

elif classification_head:
model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight
@@ -247,45 +259,64 @@ def convert_esm_checkpoint_to_pytorch(
model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias
model.lm_head.decoder.weight = esm.lm_head.weight
model.lm_head.decoder.bias = esm.lm_head.bias
model.lm_head.bias = esm.lm_head.bias

# Let's check that we get the same results.
batch_converter = alphabet.get_batch_converter()

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
if is_folding_model:
# Folding models aren't trained on masked inputs and don't like mask tokens.
sample_data = SAMPLE_DATA[:2]
else:
sample_data = SAMPLE_DATA

batch_labels, batch_strs, batch_tokens = batch_converter(SAMPLE_DATA)

batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
# Prepare tokenizer and make sure it matches
with TemporaryDirectory() as tempdir:
vocab = "\n".join(alphabet.all_toks)
vocab_file = Path(tempdir) / "vocab.txt"
vocab_file.write_text(vocab)
hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))

hf_tokens = hf_tokenizer([row[1] for row in SAMPLE_DATA], return_tensors="pt", padding=True)
hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
success = torch.all(hf_tokens["input_ids"] == batch_tokens)
print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
if not success:
raise Exception("Tokenization does not match!")

with torch.no_grad():
our_output = model(**hf_tokens, output_hidden_states=True)
our_hidden_states = our_output['hidden_states']
our_output = our_output["logits"]
if classification_head:
their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
if is_folding_model:
# Let's test the model in parts
# ESMFold always converts the ESM stem to float16, which requires float16 ops
# that don't exist on CPU. Therefore, to test it we need to run it on GPU. However,
# ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the
# original and the converted model on the GPU at the same time.
compare_attention_layers(model, esm)
our_output = model.cuda()(input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda())
their_output = esm.cuda()(hf_tokens["input_ids"].cuda(), hf_tokens["attention_mask"].cuda())
their_output_again = esm.cuda()(hf_tokens["input_ids"].cuda(), hf_tokens["attention_mask"].cuda())
else:
their_output = esm(batch_tokens, repr_layers=list(range(999)))
their_hidden_states = their_output['representations']
their_output = their_output["logits"]
our_output = model(**hf_tokens, output_hidden_states=True)
our_hidden_states = our_output['hidden_states']
our_output = our_output["logits"]
if classification_head:
their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
else:
their_output = esm(hf_tokens['input_ids'], repr_layers=list(range(999)))
their_hidden_states = their_output['representations']
their_output = their_output["logits"]

if is_folding_model:

print(our_output.shape, their_output.shape)
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
success = torch.allclose(our_output, their_output, atol=3e-4)
print("Do both models output the same tensors?", "🔥" if success else "💩")
breakpoint()
print()
else:
print(our_output.shape, their_output.shape)
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
success = torch.allclose(our_output, their_output, atol=3e-4)
print("Do both models output the same tensors?", "🔥" if success else "💩")

if not success:
raise Exception("Something went wRoNg")
@@ -294,6 +325,11 @@ def convert_esm_checkpoint_to_pytorch(
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)

test_reload = EsmForMaskedLM.from_pretrained(pytorch_dump_folder_path)
test_out = test_reload(**hf_tokens)

breakpoint()

print(f"Saving tokenizer to {pytorch_dump_folder_path}")
hf_tokenizer.save_pretrained(pytorch_dump_folder_path)

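The conversion script above introduces a transfer_and_check_weights helper built on the named tuple that torch.nn.Module.load_state_dict returns. A minimal, self-contained sketch of the same verification pattern, using strict=False so the mismatch lists are returned rather than raised (the source/target modules below are toy stand-ins, not modules from this PR):

import torch
import torch.nn as nn

def transfer_and_check_weights(original_module: nn.Module, our_module: nn.Module) -> None:
    # load_state_dict returns a named tuple with .missing_keys and .unexpected_keys;
    # strict=False lets us inspect mismatches ourselves instead of raising immediately.
    status = our_module.load_state_dict(original_module.state_dict(), strict=False)
    if status.missing_keys:
        raise ValueError(f"Missing keys: {status.missing_keys}")
    if status.unexpected_keys:
        raise ValueError(f"Unexpected keys: {status.unexpected_keys}")

# Toy usage: two modules with identical parameter names transfer cleanly.
source = nn.Linear(8, 4)
target = nn.Linear(8, 4)
transfer_and_check_weights(source, target)
assert torch.equal(source.weight, target.weight)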
3 changes: 2 additions & 1 deletion src/transformers/models/esm/modeling_esm.py
@@ -464,6 +464,7 @@ def forward(
)
attention_output = self_attention_outputs[0]


# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
@@ -505,7 +506,7 @@ def forward(
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)

print(outputs)
return outputs

def feed_forward_chunk(self, attention_output):
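The temporary print(outputs) added here serves the commit's stated goal of locating output discrepancies between the original and ported models. A small, hypothetical helper in the same spirit as the max_absolute_diff / torch.allclose check the conversion script applies to the final logits (the tensors below are random stand-ins, not real ESM activations):

import torch

def report_discrepancy(ours: torch.Tensor, theirs: torch.Tensor, name: str, atol: float = 3e-4) -> None:
    # Print the largest element-wise difference and whether the tensors match within tolerance.
    max_absolute_diff = torch.max(torch.abs(ours - theirs)).item()
    match = torch.allclose(ours, theirs, atol=atol)
    print(f"{name}: max_absolute_diff = {max_absolute_diff:.2e}", "🔥" if match else "💩")

# Toy usage with random stand-in activations.
ours = torch.randn(2, 16, 32)
theirs = ours + 1e-5 * torch.randn_like(ours)
report_discrepancy(ours, theirs, "layer_0_hidden_state")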
28 changes: 13 additions & 15 deletions src/transformers/models/esm/modeling_esmfold.py
Expand Up @@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Tuple, List, Optional, Union, Dict
from .modeling_esm import EsmPreTrainedModel, EsmModel
from .modeling_esm import EsmPreTrainedModel, EsmModel, EsmForMaskedLM

import torch
import torch.nn as nn
@@ -86,8 +86,6 @@ def __init__(self, config=None, **kwargs):
if self.config.esmfold_config.embed_aa:
self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)

# Matt: We want our config objects to be JSON serializable, which means avoiding having dataclasses
# as keys.
trunk_cfg_dict = self.config.esmfold_config.trunk

self.trunk = FoldingTrunk(trunk_cfg_dict)
@@ -102,7 +100,6 @@ def __init__(self, config=None, **kwargs):
nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
)
# cls eos padding mask

@staticmethod
def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
@@ -112,26 +109,26 @@ def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:

def forward(
self,
input: torch.Tensor,
mask: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
mask_aa: bool = False,
residx: Optional[torch.Tensor] = None,
masking_pattern: Optional[torch.Tensor] = None,
):
cfg = self.config.esmfold_config

aa = input # B x L
aa = input_ids # B x L
B = aa.shape[0]
L = aa.shape[1]
device = input.device
device = input_ids.device
if residx is None:
residx = torch.arange(L, device=device).expand_as(input)
residx = torch.arange(L, device=device).expand_as(input_ids)

# === ESM ===
esmaa = self.af2_idx_to_esm_idx(aa, mask)
esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)

if (self.training or mask_aa) and masking_pattern is not None:
masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, mask, masking_pattern)
masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
else:
masked_aa = aa
mlm_targets = None
@@ -165,12 +162,12 @@ def forward(
if cfg.use_esm_attn_map:
s_z_0 = self.esm_z_mlp(esm_z)
else:
s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk_pairwise_state_dim)
s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)

if self.config.esmfold_config.embed_aa:
s_s_0 += self.embedding(masked_aa)

structure: dict = self.trunk(s_s_0, s_z_0, aa, residx, mask)
structure: dict = self.trunk(s_s_0, s_z_0, aa, residx, attention_mask)
# Documenting what we expect:
structure = {
k: v
@@ -209,7 +206,7 @@ def forward(
"atom14_atom_exists",
"atom37_atom_exists",
]:
structure[k] *= mask.unsqueeze(-1)
structure[k] *= attention_mask.unsqueeze(-1)
structure["residue_index"] = residx

lddt_head = self.lddt_head(structure["states"]).reshape(
@@ -259,7 +256,8 @@ def compute_language_model_representations(
# _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
# Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
# esm_z is always None
esm_hidden_states = self.esm(esmaa, output_hidden_states=True)["hidden_states"]
esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
print(esm_hidden_states[0])
esm_s = torch.stack(esm_hidden_states, dim=2)
esm_z = None

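The final hunk builds the ESM attention mask directly from the token ids, treating id 1 as padding (attention_mask=esmaa != 1), and stacks the per-layer hidden states with torch.stack(..., dim=2) into a (batch, length, num_layers, hidden) tensor of per-layer representations. A stand-alone sketch of those two steps with toy ids and activations (the padding id of 1 mirrors the comparison in the diff; the shapes are illustrative only):

import torch

ESM_PADDING_IDX = 1  # mirrors the `esmaa != 1` comparison above

# Toy batch of ESM token ids; the second sequence is padded to length 6.
esmaa = torch.tensor([
    [0, 5, 6, 7, 8, 2],
    [0, 9, 10, 2, 1, 1],
])

# Positions that are not padding participate in attention.
attention_mask = (esmaa != ESM_PADDING_IDX).long()
print(attention_mask)
# tensor([[1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 0, 0]])

# Stack per-layer hidden states (a list of [B, L, H] tensors) along a new layer axis,
# giving [B, L, num_layers, H] as expected downstream.
num_layers, hidden = 3, 8
esm_hidden_states = [torch.randn(2, 6, hidden) for _ in range(num_layers)]
esm_s = torch.stack(esm_hidden_states, dim=2)
print(esm_s.shape)  # torch.Size([2, 6, 3, 8])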