|
30 | 30 | register_input_transform, |
31 | 31 | ) |
32 | 32 | import os |
33 | | -if os.environ.get("AF3_OPTIMIZATIONS_MODE") == "baseline": |
34 | | - from megafold.baseline import ( |
35 | | - AdaptiveLayerNorm, |
36 | | - AttentionPairBias, |
37 | | - CentreRandomAugmentation, |
38 | | - ComputeModelSelectionScore, |
39 | | - ComputeRankingScore, |
40 | | - ConditionWrapper, |
41 | | - ConfidenceHead, |
42 | | - ConfidenceHeadLogits, |
43 | | - DiffusionModule, |
44 | | - DiffusionTransformer, |
45 | | - DistogramHead, |
46 | | - ElucidatedAtomDiffusion, |
47 | | - InputFeatureEmbedder, |
48 | | - MSAModule, |
49 | | - MSAPairWeightedAveraging, |
50 | | - MultiChainPermutationAlignment, |
51 | | - MegaFold, |
52 | | - MegaFoldWithHubMixin, |
53 | | - OuterProductMean, |
54 | | - PairformerStack, |
55 | | - PreLayerNorm, |
56 | | - RelativePositionEncoding, |
57 | | - TemplateEmbedder, |
58 | | - Transition, |
59 | | - TriangleAttention, |
60 | | - TriangleMultiplication, |
61 | | - ) |
62 | | -elif os.environ.get("AF3_OPTIMIZATIONS_MODE") == "layernormlinear": |
63 | | - from megafold.megafold_layernormlinear import ( |
64 | | - AdaptiveLayerNorm, |
65 | | - AttentionPairBias, |
66 | | - CentreRandomAugmentation, |
67 | | - ComputeModelSelectionScore, |
68 | | - ComputeRankingScore, |
69 | | - ConditionWrapper, |
70 | | - ConfidenceHead, |
71 | | - ConfidenceHeadLogits, |
72 | | - DiffusionModule, |
73 | | - DiffusionTransformer, |
74 | | - DistogramHead, |
75 | | - ElucidatedAtomDiffusion, |
76 | | - InputFeatureEmbedder, |
77 | | - MSAModule, |
78 | | - MSAPairWeightedAveraging, |
79 | | - MultiChainPermutationAlignment, |
80 | | - MegaFold, |
81 | | - MegaFoldWithHubMixin, |
82 | | - OuterProductMean, |
83 | | - PairformerStack, |
84 | | - PreLayerNorm, |
85 | | - RelativePositionEncoding, |
86 | | - TemplateEmbedder, |
87 | | - Transition, |
88 | | - TriangleAttention, |
89 | | - TriangleMultiplication, |
90 | | - ) |
91 | | -else: |
92 | | - from megafold.model.megafold import ( |
93 | | - AdaptiveLayerNorm, |
94 | | - AttentionPairBias, |
95 | | - CentreRandomAugmentation, |
96 | | - ComputeModelSelectionScore, |
97 | | - ComputeRankingScore, |
98 | | - ConditionWrapper, |
99 | | - ConfidenceHead, |
100 | | - ConfidenceHeadLogits, |
101 | | - DiffusionModule, |
102 | | - DiffusionTransformer, |
103 | | - DistogramHead, |
104 | | - ElucidatedAtomDiffusion, |
105 | | - InputFeatureEmbedder, |
106 | | - MSAModule, |
107 | | - MSAPairWeightedAveraging, |
108 | | - MultiChainPermutationAlignment, |
109 | | - MegaFold, |
110 | | - MegaFoldWithHubMixin, |
111 | | - OuterProductMean, |
112 | | - PairformerStack, |
113 | | - PreLayerNorm, |
114 | | - RelativePositionEncoding, |
115 | | - TemplateEmbedder, |
116 | | - Transition, |
117 | | - TriangleAttention, |
118 | | - TriangleMultiplication, |
119 | | - ) |
120 | 33 |
|
| 34 | +from megafold.model.megafold import ( |
| 35 | + AdaptiveLayerNorm, |
| 36 | + AttentionPairBias, |
| 37 | + CentreRandomAugmentation, |
| 38 | + ComputeModelSelectionScore, |
| 39 | + ComputeRankingScore, |
| 40 | + ConditionWrapper, |
| 41 | + ConfidenceHead, |
| 42 | + ConfidenceHeadLogits, |
| 43 | + DiffusionModule, |
| 44 | + DiffusionTransformer, |
| 45 | + DistogramHead, |
| 46 | + ElucidatedAtomDiffusion, |
| 47 | + InputFeatureEmbedder, |
| 48 | + MSAModule, |
| 49 | + MSAPairWeightedAveraging, |
| 50 | + MultiChainPermutationAlignment, |
| 51 | + MegaFold, |
| 52 | + MegaFoldWithHubMixin, |
| 53 | + OuterProductMean, |
| 54 | + PairformerStack, |
| 55 | + PreLayerNorm, |
| 56 | + RelativePositionEncoding, |
| 57 | + TemplateEmbedder, |
| 58 | + Transition, |
| 59 | + TriangleAttention, |
| 60 | + TriangleMultiplication, |
| 61 | +) |
121 | 62 | from megafold.trainer import DataLoader, Trainer |
122 | 63 | from megafold.utils.model_utils import ( |
123 | 64 | ComputeAlignmentError, |
|
0 commit comments