Correcting Acquisition Function Calculation #13

Merged Apr 20, 2024 (7 commits). Changes from all commits.
lambo/acquisitions/monte_carlo.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -71,18 +71,19 @@ def forward(self, X: array) -> Tensor:
             baseline_X = self._X_baseline
             baseline_X = baseline_X.expand(*X.shape[:-2], -1, -1)
             X_full = torch.cat([baseline_X, X], dim=-2)
+            q = X.shape[-2]
         else:
             baseline_X = copy(self.X_baseline_string)  # ensure contiguity
             baseline_X.resize(
                 baseline_X.shape[:-(X.ndim)] + X.shape[:-1] + baseline_X.shape[-1:]
             )
             X_full = concatenate([baseline_X, X], axis=-1)
+            q = X.shape[-1]
         # Note: it is important to compute the full posterior over `(X_baseline, X)`
         # to ensure that we properly sample `f(X)` from the joint distribution
         # `f(X_baseline, X) ~ P(f | D)`, given that we have already fixed the sampled
         # function values for `f(X_baseline)`
         posterior = self.model.posterior(X_full)
-        q = X.shape[-2]
         self._set_sampler(q=q, posterior=posterior)
         samples = self.sampler(posterior)[..., -q:, :]
```
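The core fix: `q` (the number of new query points) lives on a different axis depending on the candidate representation, so it must be computed per branch rather than once after concatenation. A minimal sketch of the distinction, with illustrative shapes and sequences that are assumptions, not code from the repo:

```python
import numpy as np
import torch

# Tensor candidates: shape (batch, q, d), so the q-batch sits on dim -2.
X_tensor = torch.randn(4, 3, 16)
q_tensor = X_tensor.shape[-2]   # 3

# String candidates: object arrays of shape (batch, q), so q is the last axis.
X_string = np.array([["MKT", "MRT", "MWT"]] * 4, dtype=object)
q_string = X_string.shape[-1]   # 3

# The removed line applied X.shape[-2] to both cases; for the string array above
# that yields 4 (the batch size), so self._set_sampler and the slice
# samples[..., -q:, :] would pick out the wrong number of points.
assert X_string.shape[-2] == 4 and q_string == 3
```

Computing `q` before the posterior call keeps `self._set_sampler(q=q, posterior=posterior)` consistent with the `samples[..., -q:, :]` slice for both branches.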
lambo/candidate.py (7 changes: 0 additions & 7 deletions)

```diff
@@ -8,13 +8,6 @@
 from lambo.utils import StringSubstitution, StringDeletion, StringInsertion, FoldxMutation
-
-
-def apply_mutation(base_seq, mut_pos, mut_res, tokenizer):
-    tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1]
-    mut_seq = "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):])
-    return mut_seq
-
 
 def pdb_to_residues(pdb_path, chain_id='A'):
     """
     :param pdb_path: path to pdb file (str or Path)
```

The removed helper is superseded by the op-type-aware `apply_mutation` in `lambo/utils.py` below.
lambo/models/base_surrogate.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -10,7 +10,7 @@
 
 class BaseSurrogate(torch.nn.Module):
     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-    dtype = torch.float
+    dtype = torch.double
 
     def _set_transforms(self, tokenizer, max_shift, mask_size, train_prepend=None):
         # convert from string to LongTensor of token indexes
```
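Moving the surrogate default to double precision matters because GP posteriors require solves and Cholesky factorizations of kernel matrices that are often ill-conditioned in float32; float64 is the standard recommendation for BoTorch-style GP surrogates. A contrived sketch of the failure mode (not code from the repo):

```python
import torch

# Nearly-duplicate training inputs make an RBF Gram matrix ill-conditioned.
X = torch.tensor([[0.0], [1e-4], [1.0]], dtype=torch.double)
K = torch.exp(-torch.cdist(X, X) ** 2)

print(torch.linalg.cond(K))          # huge but finite in float64
print(torch.linalg.cond(K.float()))  # effectively singular in float32

# With a small jitter the float64 Cholesky factorization still succeeds,
# whereas the float32 matrix has already rounded the near-zero eigenvalue away.
L = torch.linalg.cholesky(K + 1e-10 * torch.eye(3, dtype=torch.double))
```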
lambo/optimizers/lambo.py (21 changes: 4 additions & 17 deletions)

```diff
@@ -310,9 +310,7 @@ def optimize(self, candidate_pool, pool_targets, all_seqs, all_targets, log_pref
             else:
                 raise ValueError
 
-            # import pdb; pdb.set_trace()
-            lat_acq_vals = acq_fn(pooled_features.unsqueeze(-2))
-            # lat_acq_vals = acq_fn(pooled_features.unsqueeze(0))
+            lat_acq_vals = acq_fn(pooled_features.unsqueeze(0))
             loss = -lat_acq_vals.mean() + self.entropy_penalty * logit_entropy.mean()
 
             if self.optimize_latent:
@@ -322,31 +320,20 @@
 
             tgt_seqs = tokens_to_str(tgt_tok_idxs, self.encoder.tokenizer)
             with torch.no_grad():
-                act_acq_vals = acq_fn(tgt_seqs[..., None])
-                # act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item()
-
-            is_improved = (act_acq_vals >= best_scores)
-            best_scores = torch.where(is_improved, act_acq_vals, best_scores)
-            best_seqs = np.where(is_improved.cpu().numpy(), tgt_seqs, best_seqs)
-            # best_scores[is_improved] = act_acq_vals[is_improved]
-            # best_seqs[is_improved] = tgt_seqs[is_improved]
-
-            with torch.no_grad():
-                batch_acq_val = acq_fn(best_seqs[None, :]).mean().item()
-            curr_score = -1.0 * batch_acq_val
+                act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item()
 
             best_score, best_step, _, stop = check_early_stopping(
                 model=None,
                 best_score=best_score,
                 best_epoch=best_step,
                 best_weights=None,
-                curr_score=curr_score,
+                curr_score=-act_acq_vals,
                 curr_epoch=step_idx + 1,
                 patience=self.patience,
                 save_weights=False,
             )
             if (step_idx + 1) == best_step:
-                # best_seqs = tgt_seqs.copy()
+                best_seqs = tgt_seqs.copy()
                 best_entropy = logit_entropy.mean().item()
             if stop:
                 print(f"Early stopping at step {step_idx + 1}")
```
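The revised loop scores each proposed batch jointly (one acquisition value per batch) and keeps the batch from the best step, instead of mixing per-sequence bests across steps with `torch.where`. For a joint batch acquisition like the NEHVI variant above, per-sequence values are not independent, so elementwise selection composes batches whose joint value was never evaluated. A minimal sketch of the resulting select-best-step pattern; `run_selection`, `propose_batch`, and `toy_acq` are hypothetical stand-ins, not lambo APIs:

```python
import numpy as np
import torch

def run_selection(propose_batch, acq_fn, num_steps, patience=16):
    """Keep the jointly scored batch from the best step (early stopping on -acq)."""
    best_score, best_step, best_seqs = float("inf"), 0, None
    for step_idx in range(num_steps):
        tgt_seqs = propose_batch(step_idx)          # np.ndarray of candidate strings
        with torch.no_grad():
            act_acq_vals = acq_fn(tgt_seqs[None, :]).mean().item()
        curr_score = -act_acq_vals                  # early stopping minimizes the score
        if curr_score < best_score:
            best_score, best_step = curr_score, step_idx + 1
            best_seqs = tgt_seqs.copy()             # keep the batch together, not per-element
        elif step_idx + 1 - best_step >= patience:
            break                                   # early stopping
    return best_seqs, best_score

def toy_acq(batch):
    # Hypothetical joint acquisition: one value per q-batch (here, total length).
    return torch.tensor([float(sum(len(s) for s in b)) for b in batch])

seqs, score = run_selection(
    propose_batch=lambda i: np.array(["A" * (i % 5 + 1)] * 3, dtype=object),
    acq_fn=toy_acq,
    num_steps=10,
)
print(seqs, score)  # the length-5 batch from step 5 wins
```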
lambo/tasks/surrogate_task.py (1 change: 0 additions & 1 deletion)

```diff
@@ -19,7 +19,6 @@ def _evaluate(self, x, out, *args, **kwargs):
             cand_idx, mut_pos, mut_res_idx, op_idx = query_pt
             op_type = self.op_types[op_idx]
             base_seq = self.candidate_pool[cand_idx].mutant_residue_seq
-            mut_pos = mut_pos % len(base_seq)
             mut_res = self.tokenizer.sampling_vocab[mut_res_idx]
             mutant_seq = apply_mutation(base_seq, mut_pos, mut_res, op_type, self.tokenizer)
             candidates.append(mutant_seq)
```

The wrap-around on `mut_pos` now happens inside `apply_mutation` itself, computed against the token sequence rather than the raw string length (see `lambo/utils.py` below).
lambo/utils.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -432,7 +432,8 @@ def tokens_to_str(tok_idx_array, tokenizer):
 
 
 def apply_mutation(base_seq, mut_pos, mut_res, op_type, tokenizer):
-    tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1]
+    tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")
+    mut_pos = mut_pos % len(tokens)
 
     if op_type == 'sub':
         mut_seq = "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):])
```
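Consolidated here, `apply_mutation` takes an `op_type` and wraps out-of-range positions against the token count, where the modulo removed from `surrogate_task.py` used the raw string length, which can disagree with the tokenization. A toy end-to-end check of the substitution path; the stand-in tokenizer (whose `decode` emits space-separated residues and no special tokens) is an assumption, not the repo's tokenizer:

```python
# Toy stand-in tokenizer: round-trips a residue string to space-separated tokens.
class ToyTokenizer:
    def encode(self, seq):
        return list(seq)

    def decode(self, toks):
        return " ".join(toks)

def apply_mutation(base_seq, mut_pos, mut_res, op_type, tokenizer):
    # Sketch of the revised helper, substitution branch only.
    tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")
    mut_pos = mut_pos % len(tokens)  # out-of-range positions wrap around
    if op_type == 'sub':
        return "".join(tokens[:mut_pos] + [mut_res] + tokens[(mut_pos + 1):])
    raise NotImplementedError(op_type)

# mut_pos=7 wraps to 2 on a length-5 sequence, so the substitution lands on 'T'.
assert apply_mutation("MKTAY", 7, "W", "sub", ToyTokenizer()) == "MKWAY"
```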