Merged
3 changes: 2 additions & 1 deletion lambo/acquisitions/monte_carlo.py
@@ -71,18 +71,19 @@ def forward(self, X: array) -> Tensor:
             baseline_X = self._X_baseline
             baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1)
             X_full = torch.cat([baseline_X, X], dim=-2)
+            q = X.shape[-2]
         else:
             baseline_X = copy(self.X_baseline_string)  # ensure contiguity
             baseline_X.resize(
                 baseline_X.shape[:-(X.ndim)] + X.shape[:-1] + baseline_X.shape[-1:]
             )
             X_full = concatenate([baseline_X, X], axis=-1)
+            q = X.shape[-1]
         # Note: it is important to compute the full posterior over `(X_baseline, X)`
         # to ensure that we properly sample `f(X)` from the joint distribution
         # `f(X_baseline, X) ~ P(f | D)`, given that we have already fixed the sampled
         # function values for `f(X_baseline)`
         posterior = self.model.posterior(X_full)
-        q = X.shape[-2]
         self._set_sampler(q=q, posterior=posterior)
         samples = self.sampler(posterior)[..., -q:, :]
         # add previous nehvi from pending points
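The point of this hunk is that `q` (the number of new candidate points) lives on a different axis depending on the input representation, so it has to be computed inside the branch. A minimal sketch of the two shape conventions, with illustrative arrays that are not from the repo:

```python
import numpy as np
import torch

# Tensor inputs follow the BoTorch `batch x q x d` convention,
# so the candidate count sits on the second-to-last axis.
X_tensor = torch.randn(8, 3, 16)  # 8 batches, q=3 candidates, d=16 features
assert X_tensor.shape[-2] == 3

# String inputs have no feature axis: their shape is `batch x q`,
# so the same count sits on the last axis.
X_str = np.array([["MKT", "MRT", "MQT"]] * 8)  # 8 batches, q=3 sequences
assert X_str.shape[-1] == 3
```

The deleted line computed `q = X.shape[-2]` unconditionally after the branch, which on the string path would have picked up the batch dimension (or raised an `IndexError` on 1-D input) instead of the candidate count.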
7 changes: 0 additions & 7 deletions lambo/candidate.py
@@ -8,13 +8,6 @@

 from lambo.utils import StringSubstitution, StringDeletion, StringInsertion, FoldxMutation

-
-def apply_mutation(base_seq, mut_pos, mut_res, tokenizer):
-    tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1]
-    mut_seq = "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):])
-    return mut_seq
-
-
 def pdb_to_residues(pdb_path, chain_id='A'):
     """
     :param pdb_path: path to pdb file (str or Path)
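The deleted helper is not lost: it reappears in `lambo/utils.py` (last hunk below), extended with an `op_type` argument. Only the `'sub'` branch is visible in that hunk, but the `StringDeletion`/`StringInsertion` imports kept here suggest the consolidated version handles more than substitutions.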
2 changes: 1 addition & 1 deletion lambo/models/base_surrogate.py
@@ -10,7 +10,7 @@

 class BaseSurrogate(torch.nn.Module):
     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-    dtype = torch.float
+    dtype = torch.double

     def _set_transforms(self, tokenizer, max_shift, mask_size, train_prepend=None):
         # convert from string to LongTensor of token indexes
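Defaulting the surrogate to `torch.double` is the usual BoTorch/GPyTorch hygiene: GP marginal-likelihood and posterior computations rest on Cholesky factorizations that are markedly more stable in float64. A small sketch of the downstream consequence, assuming a `BaseSurrogate`-style class attribute (names illustrative):

```python
import torch

class Surrogate(torch.nn.Module):
    # Mirrors the class attributes in the hunk above.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    dtype = torch.double

surrogate = Surrogate()
# Inputs must be cast to the surrogate dtype before any model call;
# mixing float32 features with float64 GP parameters raises a dtype error.
features = torch.randn(4, 16).to(surrogate.device, surrogate.dtype)
assert features.dtype == torch.double
```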
21 changes: 4 additions & 17 deletions lambo/optimizers/lambo.py
@@ -310,9 +310,7 @@ def optimize(self, candidate_pool, pool_targets, all_seqs, all_targets, log_pref
             else:
                 raise ValueError

-            # import pdb; pdb.set_trace()
-            lat_acq_vals = acq_fn(pooled_features.unsqueeze(-2))
-            # lat_acq_vals = acq_fn(pooled_features.unsqueeze(0))
+            lat_acq_vals = acq_fn(pooled_features.unsqueeze(0))
             loss = -lat_acq_vals.mean() + self.entropy_penalty * logit_entropy.mean()

             if self.optimize_latent:
@@ -322,31 +320,20 @@ def optimize(self, candidate_pool, pool_targets, all_seqs, all_targets, log_pref

             tgt_seqs = tokens_to_str(tgt_tok_idxs, self.encoder.tokenizer)
             with torch.no_grad():
-                act_acq_vals = acq_fn(tgt_seqs[..., None])
-                # act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item()
-
-            is_improved = (act_acq_vals >= best_scores)
-            best_scores = torch.where(is_improved, act_acq_vals, best_scores)
-            best_seqs = np.where(is_improved.cpu().numpy(), tgt_seqs, best_seqs)
-            # best_scores[is_improved] = act_acq_vals[is_improved]
-            # best_seqs[is_improved] = tgt_seqs[is_improved]
-
-            with torch.no_grad():
-                batch_acq_val = acq_fn(best_seqs[None, :]).mean().item()
-            curr_score = -1.0 * batch_acq_val
+                act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item()

             best_score, best_step, _, stop = check_early_stopping(
                 model=None,
                 best_score=best_score,
                 best_epoch=best_step,
                 best_weights=None,
-                curr_score=curr_score,
+                curr_score=-act_acq_vals,
                 curr_epoch=step_idx + 1,
                 patience=self.patience,
                 save_weights=False,
             )
             if (step_idx + 1) == best_step:
-                # best_seqs = tgt_seqs.copy()
+                best_seqs = tgt_seqs.copy()
                 best_entropy = logit_entropy.mean().item()
             if stop:
                 print(f"Early stopping at step {step_idx + 1}")
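The net effect of these two hunks: the first batches the latent features with a leading `unsqueeze(0)` rather than a q-axis `unsqueeze(-2)`, matching how `tgt_seqs[None, :]` is batched below; the second drops the per-sequence `best_scores`/`is_improved` bookkeeping and instead scores each step with one scalar batch acquisition value, letting `check_early_stopping` decide when a step is the new best and copying the sequences only then. A stripped-down sketch of the resulting control flow, with `acq_fn` and `check_early_stopping` stubbed out since their real implementations live elsewhere in the repo:

```python
import numpy as np

def check_early_stopping(best_score, best_epoch, curr_score, curr_epoch, patience):
    # Stub with the same contract as the repo helper: lower scores are better.
    if best_score is None or curr_score < best_score:
        best_score, best_epoch = curr_score, curr_epoch
    stop = (curr_epoch - best_epoch) >= patience
    return best_score, best_epoch, stop

def acq_fn(seq_batch):
    # Toy stand-in acquisition: higher is better.
    return -np.mean([len(s) for s in seq_batch[0]])

best_score, best_step, best_seqs = None, 0, None
for step_idx in range(16):
    tgt_seqs = np.array([f"MK{'A' * (step_idx % 5)}T"])  # stand-in decoded batch
    act_acq_vals = acq_fn(tgt_seqs[None, :])
    best_score, best_step, stop = check_early_stopping(
        best_score, best_step, curr_score=-act_acq_vals,
        curr_epoch=step_idx + 1, patience=3,
    )
    if (step_idx + 1) == best_step:
        best_seqs = tgt_seqs.copy()  # copy only when this step set a new best
    if stop:
        break
```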
1 change: 0 additions & 1 deletion lambo/tasks/surrogate_task.py
@@ -19,7 +19,6 @@ def _evaluate(self, x, out, *args, **kwargs):
             cand_idx, mut_pos, mut_res_idx, op_idx = query_pt
             op_type = self.op_types[op_idx]
             base_seq = self.candidate_pool[cand_idx].mutant_residue_seq
-            mut_pos = mut_pos % len(base_seq)
             mut_res = self.tokenizer.sampling_vocab[mut_res_idx]
             mutant_seq = apply_mutation(base_seq, mut_pos, mut_res, op_type, self.tokenizer)
             candidates.append(mutant_seq)
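Dropping the wrap here is not a loss of safety: as the final hunk shows, `apply_mutation` now wraps `mut_pos` itself, and it does so modulo the token count rather than the character count of `base_seq`. That is the right index space, since `apply_mutation` indexes into the token list, whose length need not equal the raw string length.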
3 changes: 2 additions & 1 deletion lambo/utils.py
@@ -432,7 +432,8 @@ def tokens_to_str(tok_idx_array, tokenizer):


 def apply_mutation(base_seq, mut_pos, mut_res, op_type, tokenizer):
-    tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1]
+    tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")
+    mut_pos = mut_pos % len(tokens)

     if op_type == 'sub':
         mut_seq = "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):])
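The hunk only shows the `'sub'` branch, but together with the `StringDeletion`/`StringInsertion` imports left behind in `lambo/candidate.py` it implies deletion and insertion ops as well. A self-contained sketch of the pattern under that assumption, with a toy tokenizer standing in for the repo's real one (the `'del'`/`'ins'` branch names are guesses, not confirmed by the diff):

```python
class ToyTokenizer:
    """Stand-in for the repo tokenizer: one residue per token."""
    def encode(self, seq):
        return list(seq)
    def decode(self, tokens):
        return " ".join(tokens)

def apply_mutation(base_seq, mut_pos, mut_res, op_type, tokenizer):
    tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")
    mut_pos = mut_pos % len(tokens)  # wrap in token space, as in the hunk
    if op_type == 'sub':
        return "".join(tokens[:mut_pos] + [mut_res] + tokens[mut_pos + 1:])
    if op_type == 'del':  # assumed branch, not shown in the diff
        return "".join(tokens[:mut_pos] + tokens[mut_pos + 1:])
    if op_type == 'ins':  # assumed branch, not shown in the diff
        return "".join(tokens[:mut_pos] + [mut_res] + tokens[mut_pos:])
    raise ValueError(f"unknown op_type: {op_type}")

assert apply_mutation("MKT", 1, "R", 'sub', ToyTokenizer()) == "MRT"
assert apply_mutation("MKT", 4, "R", 'sub', ToyTokenizer()) == "MRT"  # 4 % 3 == 1
```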