Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ESMFold #19977

Merged
merged 46 commits into from
Nov 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
5fdd272
initial commit
Rocketknight1 Oct 20, 2022
2996fce
First draft that gets outputs without crashing!
Rocketknight1 Oct 25, 2022
74229d2
Add all the ported openfold dependencies
Rocketknight1 Oct 26, 2022
4577374
testing
Rocketknight1 Oct 27, 2022
f1641cc
Restructure config files for ESMFold
Rocketknight1 Oct 28, 2022
65569f0
Debugging to find output discrepancies
Rocketknight1 Oct 31, 2022
26f3e25
Mainly style
sgugger Oct 31, 2022
988752a
Make model runnable without extra deps
sgugger Oct 31, 2022
92a5887
Remove utils and merge them to the modeling file
sgugger Oct 31, 2022
dd1f480
Use correct gelu and remove some debug prints
Rocketknight1 Oct 31, 2022
6e31a39
More cleanup
sgugger Oct 31, 2022
0e69779
Update esm docs
Rocketknight1 Oct 31, 2022
010fab6
Update conversion script to support ESMFold properly
Rocketknight1 Oct 31, 2022
94d0663
Port some top-level changes from ESMFold repo
Rocketknight1 Oct 31, 2022
728f21f
Expand EsmFold docstrings
Rocketknight1 Oct 31, 2022
28839b7
Make attention_mask optional (default to all 1s)
Rocketknight1 Oct 31, 2022
48c97f4
Add inference test for ESMFold
Rocketknight1 Oct 31, 2022
2422d11
Use config and not n kwargs
sgugger Oct 31, 2022
fcbf85d
Merge branch 'add_esmfold' of github.com:huggingface/transformers int…
sgugger Oct 31, 2022
e7bf6a5
Add modeling output class
sgugger Oct 31, 2022
69d5169
Remove einops
sgugger Oct 31, 2022
5770297
Remove chunking in ESM FFN
Rocketknight1 Oct 31, 2022
f8a9945
Update tests for ESMFold
Rocketknight1 Oct 31, 2022
6ab675c
Quality
sgugger Oct 31, 2022
88757cb
Merge branch 'add_esmfold' of github.com:huggingface/transformers int…
sgugger Oct 31, 2022
1ead4c0
REpo consistency
sgugger Oct 31, 2022
cff0224
Remove tree dependency from ESMFold
Rocketknight1 Oct 31, 2022
b83a592
Merge remote-tracking branch 'origin/add_esmfold' into add_esmfold
Rocketknight1 Oct 31, 2022
5b0fbae
make fixup
Rocketknight1 Oct 31, 2022
bac51f2
Add an error in case my structure map function breaks later
Rocketknight1 Oct 31, 2022
61d6581
Remove needless code
sgugger Oct 31, 2022
44ed50f
Fix merge conflicts
sgugger Oct 31, 2022
f5e7575
Stop auto-casting the LM to float16 so CPU tests pass
Rocketknight1 Oct 31, 2022
8bbf375
Stop auto-casting the LM to float16 so CPU tests pass
Rocketknight1 Oct 31, 2022
7632a12
Final test updates
Rocketknight1 Oct 31, 2022
d14fddb
Split test file
sgugger Oct 31, 2022
a91465b
Copyright and quality
sgugger Oct 31, 2022
e45e6fc
Unpin PyTorch to see built doc
sgugger Oct 31, 2022
7bebbbf
Fix config file to_dict() method
Rocketknight1 Oct 31, 2022
60d681c
Add some docstrings to the output
Rocketknight1 Oct 31, 2022
e2d9ff7
Skip TF checkpoint tests for ESM until we reupload those
Rocketknight1 Oct 31, 2022
e765f59
make fixup
Rocketknight1 Oct 31, 2022
9c9f9fa
More docstrings
Rocketknight1 Nov 1, 2022
fd455c4
Unpin to get even with main
sgugger Nov 1, 2022
020a302
Merge branch 'add_esmfold' of github.com:huggingface/transformers int…
sgugger Nov 1, 2022
f6f4a2c
Flag example to write
sgugger Nov 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Remove needless code
  • Loading branch information
sgugger committed Oct 31, 2022
commit 61d658101f33bf2c97ae6c272c5895d5fe4ccc10
6 changes: 1 addition & 5 deletions src/transformers/models/esm/modeling_esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
PreTrainedModel,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_esm import EsmConfig

Expand Down
18 changes: 10 additions & 8 deletions src/transformers/models/esm/modeling_esmfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,20 @@
)
from .configuration_esm import EsmConfig
from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel
from .openfold_data_transforms import make_atom14_masks
from .openfold_np import residue_constants
from .openfold_np.protein import Protein as OFProtein
from .openfold_np.protein import to_pdb
from .openfold_utils.chunk_utils import chunk_layer
from .openfold_utils.feats import (
from .openfold_utils import (
OFProtein,
Rigid,
Rotation,
atom14_to_atom37,
chunk_layer,
compute_predicted_aligned_error,
compute_tm,
frames_and_literature_positions_to_atom14_pos,
make_atom14_masks,
residue_constants,
to_pdb,
torsion_angles_to_frames,
)
from .openfold_utils.loss import compute_predicted_aligned_error, compute_tm
from .openfold_utils.rigid_utils import Rigid, Rotation


logger = logging.get_logger(__name__)
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions src/transformers/models/esm/openfold_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# flake8: noqa
from .chunk_utils import chunk_layer
from .data_transforms import make_atom14_masks
from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames
from .loss import compute_predicted_aligned_error, compute_tm
from .protein import Protein as OFProtein
from .protein import to_pdb
from .rigid_utils import Rigid, Rotation
32 changes: 0 additions & 32 deletions src/transformers/models/esm/openfold_utils/argparse.py

This file was deleted.

17 changes: 0 additions & 17 deletions src/transformers/models/esm/openfold_utils/callbacks.py

This file was deleted.

94 changes: 0 additions & 94 deletions src/transformers/models/esm/openfold_utils/checkpointing.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from functools import reduce, wraps
from operator import add

import numpy as np
import torch

from .openfold_np import residue_constants as rc
from .openfold_utils.tensor_utils import tensor_tree_map, tree_map
from . import residue_constants as rc
from .tensor_utils import tensor_tree_map, tree_map


def make_atom14_masks(protein):
Expand Down

This file was deleted.

7 changes: 1 addition & 6 deletions src/transformers/models/esm/openfold_utils/feats.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Dict

import numpy as np
import torch
import torch.nn as nn

from ..openfold_np import protein
from ..openfold_np import residue_constants as rc
from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather, one_hot, tensor_tree_map, tree_map

Expand Down
Loading