Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for protein sequence matching #15

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions CodonTransformer/CodonData.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,26 +176,30 @@

# Handle ambiguous amino acids based on the specified behavior
config = ProteinConfig()
ambiguous_aminoacid_map_override = config.get('ambiguous_aminoacid_map_override')
ambiguous_aminoacid_behavior = config.get('ambiguous_aminoacid_behavior')
ambiguous_aminoacid_map_override = config.get("ambiguous_aminoacid_map_override")
ambiguous_aminoacid_behavior = config.get("ambiguous_aminoacid_behavior")
ambiguous_aminoacid_map = AMBIGUOUS_AMINOACID_MAP.copy()

for aminoacid, standard_aminoacids in ambiguous_aminoacid_map_override.items():
ambiguous_aminoacid_map[aminoacid] = standard_aminoacids

if ambiguous_aminoacid_behavior == 'raise_error':
if ambiguous_aminoacid_behavior == "raise_error":
if any(aminoacid in ambiguous_aminoacid_map for aminoacid in protein):
raise ValueError("Ambiguous amino acids found in protein sequence.")
elif ambiguous_aminoacid_behavior == 'standardize_deterministic':
elif ambiguous_aminoacid_behavior == "standardize_deterministic":
protein = "".join(
ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0] for aminoacid in protein
ambiguous_aminoacid_map.get(aminoacid, [aminoacid])[0]
for aminoacid in protein
)
elif ambiguous_aminoacid_behavior == 'standardize_random':
elif ambiguous_aminoacid_behavior == "standardize_random":
protein = "".join(
random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid])) for aminoacid in protein
random.choice(ambiguous_aminoacid_map.get(aminoacid, [aminoacid]))
for aminoacid in protein
)
else:
raise ValueError(f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}.")
raise ValueError(

Check warning on line 200 in CodonTransformer/CodonData.py

View check run for this annotation

Codecov / codecov/patch

CodonTransformer/CodonData.py#L200

Added line #L200 was not covered by tests
f"Invalid ambiguous_aminoacid_behavior: {ambiguous_aminoacid_behavior}."
)

# Check for sequence validity
if any(aminoacid not in AMINO_ACIDS + STOP_SYMBOLS for aminoacid in protein):
Expand Down
20 changes: 19 additions & 1 deletion CodonTransformer/CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from CodonTransformer.CodonData import get_merged_seq
from CodonTransformer.CodonUtils import (
AMINO_ACID_TO_INDEX,
INDEX2TOKEN,
NUM_ORGANISMS,
ORGANISM2ID,
Expand All @@ -41,6 +42,7 @@ def predict_dna_sequence(
temperature: float = 0.2,
top_p: float = 0.95,
num_sequences: int = 1,
match_protein: bool = False,
) -> Union[DNASequencePrediction, List[DNASequencePrediction]]:
"""
Predict the DNA sequence(s) for a given protein using the CodonTransformer model.
Expand Down Expand Up @@ -83,6 +85,9 @@ def predict_dna_sequence(
The value must be a float between 0 and 1. Defaults to 0.95.
num_sequences (int, optional): The number of DNA sequences to generate. Only applicable
when deterministic is False. Defaults to 1.
match_protein (bool, optional): Ensures the predicted DNA sequence is translated
to the input protein sequence by sampling from only the respective codons of
given amino acids. Defaults to False.

Returns:
Union[DNASequencePrediction, List[DNASequencePrediction]]: An object or list of objects
Expand Down Expand Up @@ -198,6 +203,19 @@ def predict_dna_sequence(
# Get the model predictions
output_dict = model(**tokenized_input, return_dict=True)
logits = output_dict.logits.detach().cpu()
logits = logits[:, 1:-1, :] # Remove [CLS] and [SEP] tokens

# Mask the logits of codons that do not correspond to the input protein sequence
if match_protein:
possible_tokens_per_position = [
AMINO_ACID_TO_INDEX[token[0]] for token in merged_seq.split(" ")
]
mask = torch.full_like(logits, float("-inf"))

for pos, possible_tokens in enumerate(possible_tokens_per_position):
mask[:, pos, possible_tokens] = 0

logits = mask + logits

predictions = []
for _ in range(num_sequences):
Expand All @@ -211,7 +229,7 @@ def predict_dna_sequence(

predicted_dna = list(map(INDEX2TOKEN.__getitem__, predicted_indices))
predicted_dna = (
"".join([token[-3:] for token in predicted_dna[1:-1]]).strip().upper()
"".join([token[-3:] for token in predicted_dna]).strip().upper()
)

predictions.append(
Expand Down
52 changes: 36 additions & 16 deletions CodonTransformer/CodonUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@
# Index-to-token mapping, reverse of TOKEN2INDEX
INDEX2TOKEN: Dict[int, str] = {i: c for c, i in TOKEN2INDEX.items()}

# For every amino acid and stop symbol: the sorted indices of the tokens whose
# first character (upper-cased) is that symbol. Tokens ending in "unk" are
# left out, so only concrete codon tokens remain.
AMINO_ACID_TO_INDEX = {
    symbol: sorted(
        index
        for token, index in TOKEN2INDEX.items()
        if token[-3:] != "unk" and token[0].upper() == symbol
    )
    for symbol in (AMINO_ACIDS + STOP_SYMBOLS)
}


# Mask token mapping
TOKEN2MASK: Dict[int, int] = {
0: 0,
Expand Down Expand Up @@ -550,14 +559,15 @@
"""
Abstract base class for managing configuration settings.
"""

def __enter__(self):
    """Enter the ``with`` block and hand back this same instance."""
    return self

def __exit__(self, exc_type, exc_value, traceback):
    """
    Leave the ``with`` block: report a pending exception and reset the config.

    Args:
        exc_type: Exception class raised inside the block, or None.
        exc_value: The raised exception instance, or None.
        traceback: The traceback associated with the exception, or None.

    Nothing is returned, so an exception raised inside the block is not
    suppressed and keeps propagating to the caller.
    """
    if exc_type is not None:
        print(f"Exception occurred: {exc_type}, {exc_value}, {traceback}")
    # NOTE(review): assumed to run on every exit, not only after an
    # exception — confirm against the original indentation.
    self.reset_config()

@abstractmethod
def reset_config(self) -> None:
"""Reset the configuration to default values."""
Expand Down Expand Up @@ -601,7 +611,8 @@
def validate_inputs(self, key: str, value: Any) -> None:
"""Validate the inputs for the configuration."""
pass



class ProteinConfig(ConfigManager):
"""
A class to manage configuration settings for protein sequences.
Expand All @@ -613,6 +624,7 @@
_instance (Optional[ConfigManager]): The singleton instance of the ConfigManager.
_config (Dict[str, Any]): The configuration dictionary.
"""

_instance = None

def __new__(cls):
Expand All @@ -626,36 +638,44 @@
cls._instance = super(ProteinConfig, cls).__new__(cls)
cls._instance.reset_config()
return cls._instance

def validate_inputs(self, key: str, value: Any) -> None:
"""
Validate the inputs for the configuration.

Args:
key (str): The key to validate.
value (Any): The value to validate.

Raises:
ValueError: If the value is invalid.
TypeError: If the value is of the wrong type.
"""
if key == 'ambiguous_aminoacid_behavior':
if key == "ambiguous_aminoacid_behavior":
if value not in [
'raise_error',
'standardize_deterministic',
'standardize_random'
"raise_error",
"standardize_deterministic",
"standardize_random",
]:
raise ValueError(f"Invalid value for ambiguous_aminoacid_behavior: {value}.")
elif key == 'ambiguous_aminoacid_map_override':
raise ValueError(

Check warning on line 660 in CodonTransformer/CodonUtils.py

View check run for this annotation

Codecov / codecov/patch

CodonTransformer/CodonUtils.py#L660

Added line #L660 was not covered by tests
f"Invalid value for ambiguous_aminoacid_behavior: {value}."
)
elif key == "ambiguous_aminoacid_map_override":
if not isinstance(value, dict):
raise TypeError(f"Invalid type for ambiguous_aminoacid_map_override: {value}.")
raise TypeError(

Check warning on line 665 in CodonTransformer/CodonUtils.py

View check run for this annotation

Codecov / codecov/patch

CodonTransformer/CodonUtils.py#L665

Added line #L665 was not covered by tests
f"Invalid type for ambiguous_aminoacid_map_override: {value}."
)
for ambiguous_aminoacid, aminoacids in value.items():
if not isinstance(aminoacids, list):
raise TypeError(f"Invalid type for aminoacids: {aminoacids}.")
if not aminoacids:
raise ValueError(f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list.")
raise ValueError(

Check warning on line 672 in CodonTransformer/CodonUtils.py

View check run for this annotation

Codecov / codecov/patch

CodonTransformer/CodonUtils.py#L672

Added line #L672 was not covered by tests
f"Override for aminoacid '{ambiguous_aminoacid}' cannot be empty list."
)
if ambiguous_aminoacid not in AMBIGUOUS_AMINOACID_MAP:
raise ValueError(f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}")
raise ValueError(

Check warning on line 676 in CodonTransformer/CodonUtils.py

View check run for this annotation

Codecov / codecov/patch

CodonTransformer/CodonUtils.py#L676

Added line #L676 was not covered by tests
f"Invalid amino acid in ambiguous_aminoacid_map_override: {ambiguous_aminoacid}"
)
else:
raise ValueError(f"Invalid configuration key: {key}")

Expand All @@ -664,8 +684,8 @@
Reset the configuration to the default values.
"""
self._config = {
'ambiguous_aminoacid_behavior': 'standardize_random',
'ambiguous_aminoacid_map_override': {}
"ambiguous_aminoacid_behavior": "standardize_random",
"ambiguous_aminoacid_map_override": {},
}


Expand Down
3 changes: 2 additions & 1 deletion tests/test_CodonData.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
build_amino2codon_skeleton,
get_amino_acid_sequence,
is_correct_seq,
read_fasta_file,
preprocess_protein_sequence,
read_fasta_file,
)
from CodonTransformer.CodonUtils import ProteinConfig


class TestCodonData(unittest.TestCase):
def test_preprocess_protein_sequence(self):
with ProteinConfig() as config:
Expand Down
97 changes: 97 additions & 0 deletions tests/test_CodonPrediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,103 @@ def test_predict_dna_sequence_multi_diversity(self):
translated_protein = get_amino_acid_sequence(prediction.predicted_dna[:-3])
self.assertEqual(translated_protein, protein_sequence)

def test_predict_dna_sequence_match_protein_repetitive(self):
    """
    Check that match_protein=True round-trips repetitive and unusual proteins.

    Each protein is sampled at a very high temperature, so the model's own
    preferences carry little weight; the DNA can only translate back to the
    input if the sampling is restricted to the matching codons.
    """
    test_sequences = (
        "QQQQQQQQQQQQQQQQ_",
        "KRKRKRKRKRKRKRKR_",
        "PGPGPGPGPGPGPGPG_",
        "DEDEDEDEDEDEDEDEDE_",
        "M_M_M_M_M_",
        "MMMMMMMMMM_",
        "WWWWWWWWWW_",
        "CCCCCCCCCC_",
        "MWCHMWCHMWCH_",
        "Q_QQ_QQQ_QQQQ_",
        "MWMWMWMWMWMW_",
        "CCCHHHMMMWWW_",
        "_",
        "M_",
        "MGWC_",
    )

    organism = "Homo sapiens"

    for protein_sequence in test_sequences:
        # subTest lets the loop continue past a failure, so every sequence
        # that breaks the round trip is reported instead of only the first.
        with self.subTest(protein_sequence=protein_sequence):
            result = predict_dna_sequence(
                protein=protein_sequence,
                organism=organism,
                device=self.device,
                tokenizer=self.tokenizer,
                model=self.model,
                deterministic=False,
                temperature=20,  # High temperature to test protein matching
                match_protein=True,
            )

            dna_sequence = result.predicted_dna
            translated_protein = get_amino_acid_sequence(dna_sequence)

            self.assertEqual(
                translated_protein,
                protein_sequence,
                f"Translated protein must match original when match_protein=True. Failed for sequence: {protein_sequence}",
            )

def test_predict_dna_sequence_match_protein_rare_amino_acids(self):
    """Check that match_protein pins single-codon residues and keeps Leucine codons valid."""
    # M and W each have a single codon, while L has six; mixing them checks
    # both the forced and the free choices under high-temperature sampling.
    protein_sequence = "MWLLLMWLLL"
    organism = "Escherichia coli general"
    num_iterations = 10

    # Residue positions inside protein_sequence, grouped by amino acid.
    m_positions = [0, 5]
    w_positions = [1, 6]
    l_positions = [2, 3, 4, 7, 8, 9]
    valid_l_codons = {"TTA", "TTG", "CTT", "CTC", "CTA", "CTG"}

    def codon_at(dna, pos):
        """Return the three bases that encode residue number ``pos``."""
        start = pos * 3
        return dna[start : start + 3]

    # Sample the same protein repeatedly and keep only the DNA strings.
    results = [
        predict_dna_sequence(
            protein=protein_sequence,
            organism=organism,
            device=self.device,
            tokenizer=self.tokenizer,
            model=self.model,
            deterministic=False,
            temperature=20,  # High temperature to test protein matching
            match_protein=True,
        ).predicted_dna
        for _ in range(num_iterations)
    ]

    for dna_sequence in results:
        # Methionine has no alternative to ATG.
        for pos in m_positions:
            self.assertEqual(
                codon_at(dna_sequence, pos),
                "ATG",
                "Methionine must use ATG codon.",
            )

        # Tryptophan has no alternative to TGG.
        for pos in w_positions:
            self.assertEqual(
                codon_at(dna_sequence, pos),
                "TGG",
                "Tryptophan must use TGG codon.",
            )

        # Leucine may vary, but only among its six codons.
        self.assertTrue(
            all(codon_at(dna_sequence, pos) in valid_l_codons for pos in l_positions),
            "All Leucine codons must be valid",
        )


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
40 changes: 13 additions & 27 deletions tests/test_CodonUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,23 @@
class TestCodonUtils(unittest.TestCase):
def test_config_manager(self):
with ProteinConfig() as config:
config.set(
"ambiguous_aminoacid_behavior",
"standardize_deterministic"
)
config.set("ambiguous_aminoacid_behavior", "standardize_deterministic")
self.assertEqual(
config.get("ambiguous_aminoacid_behavior"),
"standardize_deterministic"
)
config.set(
"ambiguous_aminoacid_map_override",
{"X": ["A", "G"]}
config.get("ambiguous_aminoacid_behavior"), "standardize_deterministic"
)
config.set("ambiguous_aminoacid_map_override", {"X": ["A", "G"]})
self.assertEqual(
config.get("ambiguous_aminoacid_map_override"),
{"X": ["A", "G"]}
config.get("ambiguous_aminoacid_map_override"), {"X": ["A", "G"]}
)
config.update({
"ambiguous_aminoacid_behavior": "raise_error",
"ambiguous_aminoacid_map_override": {"X": ["A", "G"]},
})
self.assertEqual(
config.get("ambiguous_aminoacid_behavior"),
"raise_error"
config.update(
{
"ambiguous_aminoacid_behavior": "raise_error",
"ambiguous_aminoacid_map_override": {"X": ["A", "G"]},
}
)
self.assertEqual(config.get("ambiguous_aminoacid_behavior"), "raise_error")
self.assertEqual(
config.get("ambiguous_aminoacid_map_override"),
{"X": ["A", "G"]}
config.get("ambiguous_aminoacid_map_override"), {"X": ["A", "G"]}
)
try:
config.set("invalid_key", "invalid_value")
Expand All @@ -53,13 +43,9 @@ def test_config_manager(self):
pass
with ProteinConfig() as config:
self.assertEqual(
config.get("ambiguous_aminoacid_behavior"),
"standardize_random"
)
self.assertEqual(
config.get("ambiguous_aminoacid_map_override"),
{}
config.get("ambiguous_aminoacid_behavior"), "standardize_random"
)
self.assertEqual(config.get("ambiguous_aminoacid_map_override"), {})

def test_load_python_object_from_disk(self):
test_obj = {"key1": "value1", "key2": 2}
Expand Down