
Infinite decoding #7

Closed
vandedok opened this issue Oct 17, 2022 · 5 comments
Labels
wontfix This will not be worked on

Comments

@vandedok

vandedok commented Oct 17, 2022

Hello,

I am using the JTVAE from your repo and ran into a problem. I was doing random search over the latent space, but for some points the decoding never stops. It seems that the problem is in the `enum_assemble` function in chemutils.py (weighted-retraining/weighted_retraining/chem/jtnn/chemutils.py): the internal recursive search function either doesn't converge or converges too slowly.

This is the code to reproduce the issue (I am assuming your conda environment is present and the preprocessing has been done):

from pathlib import Path
import torch
from weighted_retraining.chem.chem_data import Vocab
from weighted_retraining.chem.chem_model import JTVAE

weighted_retraining_dir = Path("../../weighted-retraining/")

# load the vocabulary used by the pretrained model
with open(weighted_retraining_dir/"data/chem/zinc/orig_model/vocab.txt") as f:
    vocab = Vocab([x.strip() for x in f.readlines()])

pretrained_model_file = str(weighted_retraining_dir / "assets/pretrained_models/chem.ckpt")

model = JTVAE.load_from_checkpoint(
    pretrained_model_file, vocab=vocab
)

model.beta = model.hparams.beta_final 

# a latent point for which decoding never terminates
bad_latents = torch.tensor([[ -6.2894, -45.1619,  11.9765,  11.6767,  37.5106, -24.5908, -25.9559,
          40.9180,  18.4495, -30.9735, -10.9526, -31.6441, -49.6980, -36.1106,
          12.7674,   6.1417, -44.0838, -34.6051,  -9.2435,  47.8085,  41.7193,
         -44.4102,  15.3359, -38.5631,   7.2546, -48.9917,  16.5505, -45.4565,
         -49.4582,  11.6730,  13.2594, -37.0152,  39.9500, -39.3020, -16.2288,
          23.3959, -36.6568, -48.8145,  13.4714,  19.7008,  30.5797, -42.0284,
         -28.3188, -29.0985,  18.7675,  -7.5038,  10.2781,   1.0429, -24.5770,
         -15.5115,  10.9733, -18.1378, -34.5497, -25.7164, -21.9990,  14.0688]])

# this call never returns for the latent point above
with torch.no_grad():
    smiles = model.decode_deterministic(bad_latents)

Have you faced anything similar? Do you know what can be done in such situations?
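
One stopgap I'm considering is to guard the call with a timeout so the process at least doesn't hang forever. A rough, untested sketch (Unix-only; decode_with_timeout is just a hypothetical helper name, not something from the repo):

import signal

class DecodeTimeout(Exception):
    pass

def _raise_timeout(signum, frame):
    raise DecodeTimeout()

def decode_with_timeout(model, latents, seconds=60):
    # SIGALRM interrupts the decode if it runs longer than `seconds` (Unix-only);
    # `model` and `torch` come from the snippet above
    old_handler = signal.signal(signal.SIGALRM, _raise_timeout)
    signal.alarm(seconds)
    try:
        with torch.no_grad():
            return model.decode_deterministic(latents)
    except DecodeTimeout:
        return None  # treat a hanging decode like a failed one
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)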

PS Thanks for your JTVAE implementation, it's the most convenient one I have found

@AustinT
Collaborator

AustinT commented Oct 17, 2022

Hello, happy to hear that you found this implementation useful. I did not encounter this particular issue, but I am not surprised by it. Unfortunately, I don't have a fix available. My best suggestion would be to check how many solutions `enum_assemble` iterates through, and perhaps force it to terminate after a certain maximum number of iterations. The effective behaviour here would be to return None at these points rather than just hanging for a long time. Would this be a good workaround for you?

@vandedok
Author

vandedok commented Oct 21, 2022

Thank you for the reply!

Do you mean that I need to check how many search iterations the `search` function inside `enum_assemble` does?

Anyway, this may work. However, I wonder why you didn't face the problem -- it didn't take that many iterations to find these points. Also, it seems that Bayesian optimisation finds them faster than random search does.

Btw, sometimes I do get None as the decoding result. However, I didn't check what causes that.

@AustinT
Collaborator

AustinT commented Oct 23, 2022

> Thank you for the reply!
>
> Do you mean that I need to check how many search iterations the `search` function inside `enum_assemble` does?
>
> Anyway, this may work. However, I wonder why you didn't face the problem -- it didn't take that many iterations to find these points. Also, it seems that Bayesian optimisation finds them faster than random search does.
>
> Btw, sometimes I do get None as the decoding result. However, I didn't check what causes that.

Yes, that is what I meant. Maybe you are finding these points because the Python libraries have been updated? I last ran this code ~2 years ago.

Let me know if this solution is helpful and I will close the issue. However, I don't really want to make changes to this codebase at the moment since I think it is important to be able to use it to reproduce the results of our paper, and changes to the code may change the behaviour.

@vandedok
Author

Okay, I tried to limit the search calls in a crude way:

def enum_assemble(node, neighbors, prev_nodes=[], prev_amap=[], max_search_calls=20):
    ...
    ...
    n_search_calls = 0
    def search(cur_amap, depth):
        nonlocal n_search_calls
        n_search_calls += 1
        ...
        ...
        for new_amap in candidates:
            # stop recursing once the call budget is exhausted
            if n_search_calls < max_search_calls:
                search(new_amap, depth + 1)

It sort of worked: the decoding now finishes in a reasonable time for the latents I presented above, but the output looks like garbage:
[screenshot of the decoded molecule, 2022-10-28]

However, RDKit can even compute the logP for this thing.
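
For reference, this is roughly the check I mean -- a minimal sketch, assuming the decoded SMILES string parses at all:

from rdkit import Chem
from rdkit.Chem import Descriptors

mol = Chem.MolFromSmiles(smiles)  # `smiles` is the decoded string from above
if mol is not None:
    print(Descriptors.MolLogP(mol))  # Wildman-Crippen logP estimate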

Would you like me to create a pull request, or should we just leave it be?

@AustinT
Collaborator

AustinT commented Oct 30, 2022

Thanks for the code snippet @vandedok! If the search does not complete, then it makes sense that a sub-optimal molecule may be returned, as you've demonstrated. I think this is just a limitation of the JT-VAE.

My suggestion would just be to leave this, instead of submitting a PR. I will make a note of it on the README for this project. Feel free to re-open the issue if you feel that is appropriate.

@AustinT AustinT closed this as completed Oct 30, 2022
@AustinT AustinT added the wontfix This will not be worked on label Oct 30, 2022