Molecular Potential Prediction using Pre-trained EGNN and Transformer-Encoder

This repository holds an Equivariant Graph Neural Network (EGNN) + Transformer-Encoder model for end-to-end molecular potential prediction on the ANI-1 dataset. Instructions for training and evaluating the model can be found in the sections below.

The goal of this project is accurate molecular potential prediction on the ANI-1 dataset. The model presented in this repository uses a pre-trained [3] E(n) equivariant graph neural network [1], which becomes invariant in our setting because atomic positions are held static, together with a Transformer encoder [2] to capture both the local and global interactions between nodes.
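For orientation, here is a minimal sketch of that pipeline, assuming a hypothetical egnn module that returns padded per-node features of shape (batch, max_atoms, d_model). The defaults mirror config.yaml, but the class below is an illustration, not the repository's actual implementation:

    import torch.nn as nn

    # Sketch: per-node EGNN features -> Transformer encoder -> pooled energy head.
    # `egnn` is a hypothetical pre-trained EGCL stack, optionally frozen (freeze_egcl).
    class EGNNTransformerSketch(nn.Module):
        def __init__(self, egnn, d_model=256, num_heads=8, num_ffn=256,
                     num_encoder=1, num_neuron=512, dropout_r=0.1):
            super().__init__()
            self.egnn = egnn
            layer = nn.TransformerEncoderLayer(
                d_model, num_heads, dim_feedforward=num_ffn,
                dropout=dropout_r, activation="relu", batch_first=True)
            self.encoder = nn.TransformerEncoder(layer, num_layers=num_encoder)
            self.energy_head = nn.Sequential(
                nn.Linear(d_model, num_neuron), nn.SiLU(), nn.Linear(num_neuron, 1))

        def forward(self, z, pos, pad_mask):
            h = self.egnn(z, pos)                               # (B, N, d_model)
            h = self.encoder(h, src_key_padding_mask=pad_mask)  # global attention
            h = h.masked_fill(pad_mask.unsqueeze(-1), 0.0)      # zero out padding
            return self.energy_head(h.sum(dim=1))               # (B, 1) energies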

Note: the repository name was chosen for simplicity. The model presented is not a Graph Transformer in the traditional sense, i.e., one that applies the Transformer architecture directly to graph data (nodes + edges).

The complete workflow (data processing, model architecture creation, model training, and results), with detailed documentation, can be found in main.ipynb.

Sections

  1. Complete Workflow
  2. Parameter Details
  3. Environment Setup
  4. Model Training
  5. Model Evaluation

Complete Workflow

  1. Data Preparation: Place your dataset in the ./Data folder. Adjust the necessary parameters in the config.yaml file.
  2. Data Reading and Splitting: The model imports coordinates, atom species, and energies from the files in ./Data. It then divides the data into training/validation/test subsets as defined in the configuration.
  3. Pre-Processing: Additional pre-processing is performed, including initial node embeddings by atom type and subtraction of the self-interaction energies from the total energy.
  4. Data Packaging: The processed data are organized into a torch_geometric.loader.DataLoader to create a train_loader ready for training (see the sketch after this list).
  5. Training Setup: The training function is configured using the parameters specified in the configuration file.
  6. Normalization: The target data (y-values) are scaled by the user-specified scaling factor.
  7. Logging and Output: A logging function, file writing, and a TensorBoard writer are set up for monitoring the training process.
  8. Model Training: The model is trained batch-wise, and the parameters with the lowest validation loss are saved. For evaluation, the y-values are rescaled using the scaling factor.
  9. Evaluation on Test Set: Use the best-performing model to evaluate the test set and analyze the results.
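As a hedged illustration of steps 4 and 6, the sketch below packages hypothetical species/coordinate/energy arrays (illustrative names, not the repository's actual loaders) into a train_loader with scaled targets:

    import torch
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader

    SCALE_VALUE = 300.0  # scale_value from config.yaml

    def package(species, coords, energies):
        # Wrap raw arrays into Data objects; targets are scaled as in step 6.
        data_list = []
        for z, pos, e in zip(species, coords, energies):
            data_list.append(Data(
                z=torch.as_tensor(z, dtype=torch.long),       # atom species
                pos=torch.as_tensor(pos, dtype=torch.float),  # (num_atoms, 3)
                y=torch.tensor([e / SCALE_VALUE])))           # scaled energy
        return data_list

    # Toy example: a single water-like conformation (O, H, H).
    species = [[8, 1, 1]]
    coords = [[[0.00, 0.00, 0.00], [0.96, 0.00, 0.00], [-0.24, 0.93, 0.00]]]
    energies = [-76.4]  # after self-interaction energy subtraction

    train_loader = DataLoader(package(species, coords, energies),
                              batch_size=256, shuffle=True, num_workers=0)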

Parameter Details

    gpu: gpu                # Custom name for gpu device
    lr: 2e-4                # Maximum learning rate
    min_lr: 1e-7            # Minimum learning rate
    weight_decay: 0.0       # Weight decay param
    epochs: 20              # Epochs
    warmup_epochs: 0.1      # Ratio of warm up epochs
    patience_epochs: 0.9    # Ratio of patience epochs
    log_every_n_steps: 250  # Frequency of logging

    # Load_model: None 
    load_model: models/pretrained_egnn.pth

    scale_value: 300        # Energy scaling factor
    normalize_energies: false # Energy normalization
    log_transformation: false # Log transformation on energies
    freeze_epochs: 0        # Freeze EGCL for given epochs; deprecated but not yet removed from the code
    
    # EGNN/EGCL Parameters
    hidden_channels: 256    # Number of hidden channels
    num_edge_feats: 0       # Number of additional edge features
    num_egcl: 2             # Number of EGCL layers
    act_fn: SiLU            # Activation function
    residual: True          # Residual calculation
    attention: True         # Graph Attention mechanism
    normalize: True         # Interatomic distance normalization
    cutoff: 4               # Interatomic distance cutoff
    max_atom_type: 28       # Max atom types
    max_num_neighbors: 32   # Max number of neighbors
    static_coord: True      # Whether to keep coordinates static (skip coordinate updates)
    freeze_egcl: True       # Whether to freeze the EGCL weights

    # Transformer-Encoder Parameters
    d_model: 256            # Embeddings for each token
    num_encoder: 1          # Number of encoder units
    num_heads: 8            # Number of self-attention heads
    num_ffn: 256            # Number of neurons in the feedforward MLP
    act_fn_ecd: ReLU        # Activation function for encoder MLP
    dropout_r: 0.1          # Dropout rate

    # Energy Head
    num_neuron: 512         # Number of neurons in the final energy head

    batch_size: 256         # Batch size
    num_workers: 8          # Number of workers for data loaders
    valid_size: 0.1         # Validation set size
    test_size: 0.1          # Test set size
    data_dir: './Data'      # Data directory
    seed: 42                # Random seed
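A hedged sketch of reading this file follows; the key names are taken from the listing above, but train.py may consume the file differently:

    import yaml

    with open("config.yaml") as f:
        config = yaml.safe_load(f)

    # PyYAML parses bare scientific notation such as 2e-4 as a string,
    # so a defensive float() cast is useful for the learning rates.
    lr = float(config["lr"])
    min_lr = float(config["min_lr"])

    # warmup_epochs and patience_epochs are ratios of the total epochs.
    num_warmup_epochs = config["warmup_epochs"] * config["epochs"]      # 0.1 * 20 = 2
    num_patience_epochs = config["patience_epochs"] * config["epochs"]  # 0.9 * 20 = 18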

Environment Setup

make create-env         # Create conda environment
conda activate EGTF_env # Activate conda environment
conda deactivate        # Deactivate the environment
make delete-env         # Delete the conda environment

If an error occurs during installation of the torch_geometric-related packages, it may be due to outdated compilers. Consider updating your compiler version, for example on Linux: conda update gxx_linux-64 gcc_linux-64.

Model Training

To train the model using custom data, place the data in the ./Data folder. Change the config.yaml file accordingly, and input the following into the command line:

conda activate EGTF_env # Activate environment
python3 train.py    # Run python training script
conda deactivate    # Deactivate the environment
make delete-env     # Delete the conda environment
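The lr, min_lr, and warmup_epochs values in config.yaml suggest a warmup-then-decay learning-rate schedule. The function below is a hypothetical illustration of one common realization (linear warmup followed by cosine decay); it is not necessarily what train.py implements:

    import math

    def lr_at_epoch(epoch, epochs=20, lr=2e-4, min_lr=1e-7, warmup_ratio=0.1):
        warmup = max(1, int(warmup_ratio * epochs))
        if epoch < warmup:
            return lr * (epoch + 1) / warmup      # linear warmup to max lr
        t = (epoch - warmup) / max(1, epochs - warmup)
        return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * t))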

Model Evaluation

To evaluate the model from a specific run using custom data, place the data in the ./Data_eval folder, and input the following into the command line:

conda activate EGTF_env # Activate environment
python3 evaluate.py Runs_savio/$SPECIFIC_RUN
conda deactivate    # Deactivate the environment
make delete-env     # Delete the conda environment

This will load the pre-trained model architecture, parameters, and normalizer from that specific run and perform evaluation on 10% of the data in ./Data_eval.
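As a hedged sketch of the evaluation logic described in the workflow (predictions and targets are rescaled by scale_value before computing metrics), assuming illustrative model and eval_loader objects rather than the repository's exact interfaces:

    import torch

    SCALE_VALUE = 300.0  # must match the scale_value used during training

    @torch.no_grad()
    def mean_absolute_error(model, eval_loader, device="cpu"):
        model.eval()
        total, count = 0.0, 0
        for batch in eval_loader:
            batch = batch.to(device)
            pred = model(batch) * SCALE_VALUE   # rescale predictions
            true = batch.y * SCALE_VALUE        # rescale targets
            total += (pred.view(-1) - true.view(-1)).abs().sum().item()
            count += true.numel()
        return total / count  # MAE in the original energy units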

References / Acknowledgements

  1. V. G. Satorras et al., E(n) Equivariant Graph Neural Networks. [Paper] [GitHub]
  2. A. Vaswani et al., Attention is All You Need. [Paper]
  3. Y. Wang et al., Denoise Pre-training on Non-equilibrium Molecules for Accurate and Transferable Neural Potentials.
    [Paper] [GitHub]
  4. J. S. Smith et al., ANI-1: An extensible neural network potential with DFT accuracy at force field computational cost.
    [Paper] [GitHub]
  5. J. S. Smith et al., ANI-1, A data set of 20 million calculated off-equilibrium conformations for organic molecules.
    [Paper]
