-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmake_msa_seq_feats.py
executable file
·137 lines (116 loc) · 5.11 KB
/
make_msa_seq_feats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Mapping, Optional, Sequence
from alphafold.common import residue_constants
from alphafold.data import parsers
import numpy as np
import argparse
import pickle
import sys
import pdb
# Internal import (7716).
# Command-line interface. Every option is mandatory: there is no sensible
# fallback for a missing path, so argparse reports the omission itself
# instead of handing a placeholder default to the code that consumes it.
# nargs=1 is kept so each parsed value is still a one-element list.
parser = argparse.ArgumentParser(
    description="""Builds the input features for training the structure prediction model.""")
parser.add_argument('--input_fasta_path', nargs=1, type=str, required=True,
                    help='Path to fasta.')
parser.add_argument('--input_msas', nargs=1, type=str, required=True,
                    help='Path to MSAs. Separated by comma.')
parser.add_argument('--outdir', nargs=1, type=str, required=True,
                    help='Path to output directory. Include /in end')

# Type alias: feature name -> numpy array holding that feature.
FeatureDict = Mapping[str, np.ndarray]
##############FUNCTIONS##############
def make_sequence_features(
        sequence: str, description: str, num_res: int) -> FeatureDict:
    """Build the per-sequence feature dict for a single query sequence.

    Args:
      sequence: Amino-acid sequence string.
      description: Description of the sequence; stored under 'domain_name'.
      num_res: Number of residues; sets the length of the per-residue arrays.

    Returns:
      A dict with the keys 'aatype', 'between_segment_residues',
      'domain_name', 'residue_index', 'seq_length' and 'sequence'.
    """
    # One-hot encoding of the sequence; unknown residues are mapped to 'X'.
    aatype = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True)

    return {
        'aatype': aatype,
        'between_segment_residues': np.zeros((num_res,), dtype=np.int32),
        # Strings are stored as UTF-8 bytes inside one-element object arrays.
        'domain_name': np.array(
            [description.encode('utf-8')], dtype=np.object_),
        # 0, 1, ..., num_res - 1.
        'residue_index': np.arange(num_res, dtype=np.int32),
        # The sequence length repeated once per residue.
        'seq_length': np.full((num_res,), num_res, dtype=np.int32),
        'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_),
    }
def make_msa_features(
        msas: Sequence[Sequence[str]],
        deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
    """Build the MSA feature dict from one or more parsed alignments.

    Sequences are de-duplicated across all MSAs: only the first occurrence
    of a given aligned sequence string (and its deletion row) is kept.

    Args:
      msas: One list of aligned sequence strings per MSA.
      deletion_matrices: One deletion matrix per MSA, indexed by
        [msa_index][sequence_index].

    Returns:
      A dict with the keys 'deletion_matrix_int', 'msa' and
      'num_alignments'.

    Raises:
      ValueError: If `msas` is empty or any individual MSA is empty.
    """
    if not msas:
        raise ValueError('At least one MSA must be provided.')

    aa_to_id = residue_constants.HHBLITS_AA_TO_ID
    encoded_rows = []
    deletion_rows = []
    already_seen = set()
    for msa_idx, alignment in enumerate(msas):
        if not alignment:
            raise ValueError(
                f'MSA {msa_idx} must contain at least one sequence.')
        for row_idx, row in enumerate(alignment):
            if row not in already_seen:
                already_seen.add(row)
                encoded_rows.append([aa_to_id[residue] for residue in row])
                deletion_rows.append(deletion_matrices[msa_idx][row_idx])

    # The residue count is taken from the first sequence of the first MSA.
    residue_count = len(msas[0][0])
    alignment_count = len(encoded_rows)

    return {
        'deletion_matrix_int': np.array(deletion_rows, dtype=np.int32),
        'msa': np.array(encoded_rows, dtype=np.int32),
        # The number of kept alignments repeated once per residue.
        'num_alignments': np.full(
            (residue_count,), alignment_count, dtype=np.int32),
    }
def process(input_fasta_path: str, input_msas: list) -> FeatureDict:
    """Build sequence and MSA features from a FASTA file and existing MSAs.

    No alignment tools are run here; the MSAs are read from disk as given.

    Args:
      input_fasta_path: Path to a FASTA file containing exactly one sequence.
      input_msas: Paths to MSA files. The format is chosen from the last
        three characters of each path: 'sto' (Stockholm) or 'a3m'.

    Returns:
      The sequence features and the MSA features merged into a single dict.

    Raises:
      ValueError: If the FASTA file does not contain exactly one sequence.
      TypeError: If an MSA path ends in neither 'sto' nor 'a3m'.
    """
    with open(input_fasta_path) as f:
        input_fasta_str = f.read()
    input_seqs, input_desc = parsers.parse_fasta(input_fasta_str)
    # Report an empty FASTA explicitly instead of as "more than one".
    if not input_seqs:
        raise ValueError(
            f'No input sequence found in {input_fasta_path}.')
    if len(input_seqs) > 1:
        raise ValueError(
            f'More than one input sequence found in {input_fasta_path}.')
    input_sequence = input_seqs[0]
    input_description = input_desc[0]
    num_res = len(input_sequence)

    parsed_msas = []
    parsed_delmat = []
    for custom_msa in input_msas:
        # Context manager so the file handle is closed after reading.
        with open(custom_msa) as msa_file:
            msa = msa_file.read()
        if custom_msa.endswith('sto'):
            # The third value returned by the Stockholm parser is not used.
            parsed_msa, parsed_deletion_matrix, _ = parsers.parse_stockholm(msa)
        elif custom_msa.endswith('a3m'):
            parsed_msa, parsed_deletion_matrix = parsers.parse_a3m(msa)
        else:
            raise TypeError('Unknown format for input MSA, please make sure '
                            'the MSA files you provide terminates with (and '
                            'are formatted as) .sto or .a3m')
        parsed_msas.append(parsed_msa)
        parsed_delmat.append(parsed_deletion_matrix)

    sequence_features = make_sequence_features(
        sequence=input_sequence,
        description=input_description,
        num_res=num_res)
    msa_features = make_msa_features(
        msas=parsed_msas, deletion_matrices=parsed_delmat)
    # On a key clash the MSA features win; the two dicts above use
    # disjoint keys, so this is a plain merge.
    return {**sequence_features, **msa_features}
##################MAIN#######################
def main():
    """Parse the command line, build the features and pickle them.

    The feature dict is written to 'msa_features.pkl' inside the directory
    given by --outdir.
    """
    # Parse args
    args = parser.parse_args()

    # Data. Each option is declared with nargs=1, so every value arrives as
    # a one-element list.
    input_fasta_path = args.input_fasta_path[0]
    input_msas = args.input_msas[0].split(',')
    outdir = args.outdir[0]

    # Get feats
    feature_dict = process(input_fasta_path, input_msas)

    # Write out features as a pickled dictionary.
    features_output_path = os.path.join(outdir, 'msa_features.pkl')
    with open(features_output_path, 'wb') as f:
        pickle.dump(feature_dict, f, protocol=4)
    print('Saved features to', features_output_path)


# Run only when executed as a script, so importing this module for its
# feature-building functions neither parses sys.argv nor writes any file.
if __name__ == '__main__':
    main()