Searching chemical space as described in:
Molecular De Novo Design through Deep Reinforcement Learning
The current version is a PyTorch implementation that differs in several ways from the original implementation described in the paper. This version works better in most situations and is better documented, but to reproduce the results from the paper, refer to Release v1.0.1.
Differences from the implementation in the paper:
- Written in PyTorch/Python3.6 rather than TF/Python2.7
- SMILES are encoded as token indices rather than as one-hot vectors of those indices. An embedding matrix then transforms each token index into a feature vector.
- Scores are in the range (0,1).
- A regularizer that penalizes high values of total episodic likelihood is included.
- Sequences are only considered once, i.e. if the same sequence is generated twice in a batch, only the first instance contributes to the loss.
- These changes make the algorithm more robust against local minima, which means much higher values of sigma can be used if needed.
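The paper trains the Agent towards an augmented log-likelihood, the Prior log-likelihood plus sigma times the score, using a squared-error loss against the Agent log-likelihood. As a minimal sketch of how the deduplication and likelihood regularizer listed above fit into that loss, here is a plain-Python version. The function names and the regularizer weight `reg_weight` (default 5e3) are assumptions for illustration, not this repository's code, which operates on PyTorch tensors.

```python
from typing import List, Sequence, Tuple


def first_occurrence_indices(seqs: Sequence[Tuple[int, ...]]) -> List[int]:
    """Indices of the first occurrence of each distinct token sequence."""
    seen = set()
    keep = []
    for i, seq in enumerate(seqs):
        key = tuple(seq)
        if key not in seen:
            seen.add(key)
            keep.append(i)
    return keep


def agent_loss(seqs, agent_logp, prior_logp, scores, sigma, reg_weight=5e3):
    """Mean loss over the unique sequences in one sampled batch (sketch).

    agent_logp / prior_logp: total log-likelihood of each sequence (< 0).
    scores: scoring-function values in [0, 1].
    reg_weight: assumed weight of the likelihood regularizer.
    """
    keep = first_occurrence_indices(seqs)  # duplicates are ignored
    n = len(keep)
    # Squared distance between augmented likelihood and Agent likelihood.
    mse = sum(
        (prior_logp[i] + sigma * scores[i] - agent_logp[i]) ** 2 for i in keep
    ) / n
    # -1/log p grows without bound as log p -> 0 (likelihood -> 1), so this
    # term penalizes sequences the Agent assigns very high total likelihood.
    reg = -sum(1.0 / agent_logp[i] for i in keep) / n
    return mse + reg_weight * reg
```

Because the regularizer is computed per batch on the deduplicated sequences, a sequence the Agent keeps resampling cannot dominate the loss.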
This package requires:
- Python 3.6
- PyTorch 0.1.12
- RDKit
- Scikit-Learn (for QSAR scoring function)
- tqdm (for training Prior)
- pexpect
To import and preprocess the QM9 dataset, run ./Import_QM9.py. This will generate mols.smi, a list of raw input SMILES.
To train a Prior starting with a SMILES file called mols.smi:
- First filter the SMILES and construct a vocabulary from the remaining sequences by running ./data_structs.py mols.smi. This generates data/mols_filtered.smi and data/Voc.
- Then use ./train_prior.py to train the Prior. A pretrained Prior that I have already trained on QM9's training data subset is included (Prior_IshaniExample.ckpt).
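Vocabulary construction is only summarized above. The sketch below shows one common way to tokenize SMILES so that multi-character units (bracket atoms such as [nH], plus Cl and Br) become single tokens, build a vocabulary, and encode each SMILES as token indices, the index encoding this version uses in place of one-hot vectors. The regex, the special tokens `EOS`/`GO`, and the helper names are assumptions; data_structs.py may tokenize differently and also applies filtering rules not reproduced here.

```python
import re
from typing import Dict, Iterable, List

# Assumed tokenization: bracket atoms, two-letter halogens, else one character.
TOKEN_RE = re.compile(r"(\[[^\[\]]+\]|Cl|Br|.)")
SPECIAL_TOKENS = ["EOS", "GO"]  # hypothetical end-of-sequence / start tokens


def tokenize(smiles: str) -> List[str]:
    return TOKEN_RE.findall(smiles)


def build_vocab(smiles_list: Iterable[str]) -> Dict[str, int]:
    tokens = set()
    for smi in smiles_list:
        tokens.update(tokenize(smi))
    ordered = sorted(tokens) + SPECIAL_TOKENS
    return {tok: idx for idx, tok in enumerate(ordered)}


def encode(smiles: str, vocab: Dict[str, int]) -> List[int]:
    # Token indices terminated by EOS; an embedding layer would then map
    # each index to a learned feature vector.
    return [vocab[tok] for tok in tokenize(smiles)] + [vocab["EOS"]]
```

For example, `tokenize("CC(Cl)c1cc[nH]c1")` keeps `Cl` and `[nH]` intact as single tokens.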
To train an Agent using our Prior, use the ./main.py script. The full list of input parameters is given in main.py. Some key parameters:
- scoring-function: the target property the Agent is trying to optimize (e.g. no sulphur, Tanimoto similarity to a target SMILES string, band gap within a target range)
- num-steps: number of steps to train Agent
- sigma-mode: mechanism for selecting sigma, which sets the tradeoff between prioritizing Prior likelihood and high performance on the scoring function during training (e.g. static sigma as in the original code, adaptive sigma, sigma based on score uncertainty, or sigma sampled from a Lévy distribution)
- sigma: scalar value controlling tradeoff between prioritizing Prior likelihood and high performance on scoring function during training. If sigma-mode is not 'static', this value is the starting sigma in the first training step.
- prior: which trained Prior model file to use while training Agent. By default this is 'Prior.ckpt' generated by Prior training process detailed above.
- agent: which model to use as the first instantiation of the Agent model. By default, this should be the same Prior model selected as the 'prior' input.
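In the paper's formulation, sigma multiplies the score in the augmented log-likelihood (Prior log-likelihood plus sigma times score), so a larger sigma weights the scoring function more heavily relative to the Prior. The sketch below illustrates how a per-step sigma could be produced for two of the modes listed, static and Lévy sampling. The mode strings, the `scale` parameter, and adding the Lévy sample to the starting sigma are assumptions for illustration only; main.py defines the actual modes and their behavior.

```python
import random
from typing import Optional


def sigma_for_step(mode: str, sigma0: float, scale: float = 1.0,
                   rng: Optional[random.Random] = None) -> float:
    """Return the sigma for one training step (hypothetical sketch).

    static: always the starting sigma, as in the original code.
    levy:   sigma0 plus a Levy(0, scale) sample, drawn as scale / Z**2
            with Z ~ N(0, 1). The distribution is heavy-tailed, so some
            steps use a much larger sigma than others.
    """
    rng = rng or random.Random()
    if mode == "static":
        return sigma0
    if mode == "levy":
        z = 0.0
        while z == 0.0:  # guard against division by zero
            z = rng.gauss(0.0, 1.0)
        return sigma0 + scale / z ** 2
    raise ValueError(f"mode not covered by this sketch: {mode}")
```

The adaptive and uncertainty-based modes are deliberately left out, since their update rules are not described here.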
An example run:
./main.py --scoring-function bandgap_range_soft --num-steps 1000 --sigma_mode static --sigma 20
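The example run optimizes bandgap_range_soft, whose exact definition lives in the repository's scoring code. As a hypothetical sketch, a "soft" range score could return 1 when the predicted band gap lies inside the target window and decay smoothly towards 0 outside it. This keeps scores within [0, 1] and gives partial credit to near misses. The band gap prediction itself is assumed to happen elsewhere, and `low`, `high`, and `width` are illustrative parameters.

```python
import math


def soft_range_score(value: float, low: float, high: float, width: float) -> float:
    """1 inside [low, high], Gaussian decay with the given width outside (sketch)."""
    if low <= value <= high:
        return 1.0
    # Distance from the value to the nearest edge of the target window.
    dist = low - value if value < low else value - high
    return math.exp(-0.5 * (dist / width) ** 2)
```

Compared with a hard 0/1 range check, the smooth decay lets the Agent distinguish molecules that almost hit the window from ones that miss it badly.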
Scores can be visualized during training using the Vizard Bokeh app. vizard_logger.py logs information such as generated structures, average score, and network weights (by default to data/logs). To launch the app:
- cd Vizard
- ./run.sh ../data/logs
- Open the browser at http://localhost:5006/Vizard
After training, results are stored in data/results under a folder with the run date and time. Outputs include:
- results/Agent.ckpt: final trained Agent model
- results/sampled: SMILES generated by the final Agent model, along with their Agent and Prior likelihoods
- results/training_log_novel.npy: for the SMILES generated during each training step, the percentage that are not present in the original QM9 dataset
- results/training_log_sa.npy: for the SMILES generated during each training step, their synthetic accessibility scores
- results/training_log_scores.npy: training score at each training step
- results/training_log_sigmas.npy: sigma value used at each training step; if sigma_mode is not static, this changes over the course of Agent training
- results/training_log_valid.npy: for the SMILES generated during each training step, the percentage that are syntactically valid
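The .npy logs can be inspected with NumPy once a run finishes. The helper below is a hypothetical convenience function, not part of the package. It loads whichever training_log_*.npy files are present in one results folder and reports each log's number of steps, mean, and final value. It assumes each file holds a numeric array whose first axis indexes training steps; any extra axes (for example per-molecule values) are averaged down to one number per step.

```python
import os
from typing import Dict

import numpy as np

LOG_NAMES = ["novel", "sa", "scores", "sigmas", "valid"]


def summarize_run(run_dir: str) -> Dict[str, Dict[str, float]]:
    """Summarize the training_log_*.npy files found in one results folder."""
    summary = {}
    for name in LOG_NAMES:
        path = os.path.join(run_dir, f"training_log_{name}.npy")
        if not os.path.exists(path):
            continue  # skip logs this run did not produce
        arr = np.asarray(np.load(path), dtype=float)
        # Assumption: axis 0 is the training step; average any other axes.
        per_step = arr.reshape(arr.shape[0], -1).mean(axis=1) if arr.ndim > 1 else arr
        summary[name] = {
            "steps": int(per_step.shape[0]),
            "mean": float(per_step.mean()),
            "final": float(per_step[-1]),
        }
    return summary
```

Usage: call `summarize_run("data/results/<run folder>")`, replacing the placeholder with the date-and-time folder created by your run.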