Skip to content

Commit

Permalink
Update VIPER_run.py
Browse files Browse the repository at this point in the history
  • Loading branch information
maxall41 authored Jun 29, 2024
1 parent 6eee5b8 commit cfbc2a2
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions scripts/VIPER_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,38 @@

# Column of the input table whose value run() passes to gen_ankh().
SEQ_COL = 'sequence'
# Column of the input table whose value run() passes to compute_molformer_emb().
SMILES_COL = 'smiles'
# Module-level model collection: exec() rebinds it to create_models() and
# run() iterates over it, so it is empty until exec() has been called.
models = []

def std_dev_to_confidence(std, low_thresh=0.08, high_thresh=0.04):
    """Map a standard deviation onto a coarse confidence label.

    A smaller spread means higher confidence.

    Args:
        std: The standard deviation to classify.
        low_thresh: Values strictly greater than this are labelled "LOW".
        high_thresh: Values strictly less than this are labelled "HIGH".

    Returns:
        "HIGH", "LOW" or "MEDIUM". Values equal to either threshold, and
        anything in between, are "MEDIUM". The "HIGH" check is evaluated
        first. NaN compares false against both thresholds and therefore
        also yields "MEDIUM".
    """
    if std < high_thresh:
        return "HIGH"
    if std > low_thresh:
        return "LOW"
    return "MEDIUM"

def _load_ankh():
    """Return the ``(model, tokenizer)`` pair, loading it on first use only."""
    cached = getattr(_load_ankh, "_cached", None)
    if cached is None:
        model, tokenizer = ankh.load_base_model()
        model.eval()
        model.cuda()
        cached = (model, tokenizer)
        # Stored on the function object so later calls reuse the loaded
        # model instead of reloading it and moving it to the GPU again.
        _load_ankh._cached = cached
    return cached


def gen_ankh(seq):
    """Embed a single sequence with the Ankh base model.

    The model is loaded once per process and reused on subsequent calls.
    Both the model and the encoded inputs are moved to CUDA, so a CUDA
    device is required.

    Args:
        seq: The sequence to embed; it is split into individual characters
            before tokenisation.

    Returns:
        The model's ``last_hidden_state`` for ``seq`` with the leading
        batch axis removed.
    """
    model, tokenizer = _load_ankh()
    encoded = tokenizer.batch_encode_plus(
        [list(seq)],
        add_special_tokens=True,
        padding=True,
        is_split_into_words=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        output = model(
            input_ids=encoded['input_ids'].cuda(),
            attention_mask=encoded['attention_mask'].cuda(),
        )
    # Exactly one sequence is encoded, so axis 0 is a batch of size 1.
    # squeeze(0) removes only that axis; a bare squeeze() would also drop
    # any other axis that happens to have length 1.
    return output.last_hidden_state.squeeze(0)

def run(row):
global models
ankh = gen_ankh(row[SEQ_COL])
molformer = compute_molformer_emb(row[SMILES_COL])
ankh = ankh.unsqueeze(0)
molformer = molformer.unsqueeze(0)

# Run Model
results = []
for model in models:
Expand All @@ -29,15 +54,13 @@ def run(row):
row['confidence'] = std_dev_to_confidence(std)
return row




def exec(input='in.csv', output="out.csv"):
    """Run the models over every row of a CSV file and save the result.

    Args:
        input: Path of the CSV file to read.
        output: Path the processed table is written to as CSV.
    """
    # run() reads the module-level ``models``, so it is rebound here
    # before any row is processed.
    global models
    frame = pd.read_csv(input)
    models = create_models()
    processed = frame.apply(run, axis=1)
    processed.to_csv(output)
    print("Output saved!")

if __name__ == '__main__':
    # Expose exec() as the command-line entry point. fire.Fire is the
    # library's entry function; the module has no ``exec`` attribute.
    fire.Fire(exec)

0 comments on commit cfbc2a2

Please sign in to comment.