Anthony Zhou, Zijie Li, Michael Schneier, John R Buchanan Jr, Amir Barati Farimani. (Paper)
If you find our work useful, feel free to cite us:
@misc{zhou2024text2pde,
title={Text2PDE: Latent Diffusion Models for Accessible Physics Simulation},
author={Anthony Zhou and Zijie Li and Michael Schneier and John R Buchanan Jr and Amir Barati Farimani},
year={2024},
eprint={2410.01153},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2410.01153},
}
To install requirements:
conda env create -n my_env -f environment.yml
Alternatively, to manually install required packages:
conda create -n "my_env"
conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 torchdata=0.6.0 -c pytorch -c nvidia
conda install pytorch-scatter -c pyg
conda install lightning -c conda-forge
pip install transformers
pip install wandb h5py tensorboard einops open3d sparse-dot-mkl timm
Some packages are simply easier to install with pip than with conda.
Optional installs for image captioning, FLOPs profiling, and text evaluation:
pip install scikit-image deepspeed phiflow
If you cannot install pytorch<=2.0.1, please refer to the Compatibility section, as some libraries require this version.
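After installing, a quick sanity check (a minimal sketch; the expected versions follow the pins above) confirms that a CUDA build of PyTorch was picked up:

```python
import torch

# Expect 2.0.0 (or at most 2.0.1, see the Compatibility section below).
print(torch.__version__)
# Expect True on a GPU node and a CUDA version matching pytorch-cuda=11.8.
print(torch.cuda.is_available(), torch.version.cuda)
```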
Full datasets are available here.
Please refer to the dataset directory for a description of the raw data and dataloading.
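Since h5py is part of the requirements, a downloaded HDF5 file can be inspected before setting up the dataloaders. The file name below is a placeholder; the actual layout and field names are documented in the dataset directory:

```python
import h5py

# Placeholder path; point this at one of the downloaded dataset files.
with h5py.File("data/ns2D_train.h5", "r") as f:
    # Print every group/dataset and its shape to see how the raw data is laid out.
    f.visititems(lambda name, obj: print(name, getattr(obj, "shape", "")))
```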
Pretrained models are available here.
The pretrained models are:
- Autoencoders:
- ae_cylinder.ckpt: autoencoder trained to compress cylinder mesh data across 25 timesteps. Does not use GAN or LPIPS.
- ae_ns2D.ckpt: autoencoder trained to compress smoke buoyancy data (48x128x128). Does not use GAN or LPIPS.
- LDMs:
- cylinder flow
- ldm_DiT_FF_cylinder.ckpt: LDM trained to sample a cylinder flow solution conditioned on the first frame
- ldm_DiTSmall_FF_cylinder.ckpt: same as previous, just smaller DiT size.
- ldm_DiT_text_cylinder.ckpt: LDM trained to sample a cylinder flow solution conditioned on a text prompt
- ldm_DiTSmall_text_cylinder.ckpt: same as previous, just smaller DiT size.
- ns2D
- ldm_DiT_FF_ns2D.ckpt: LDM trained to sample a smoke buoyancy solution conditioned on the first frame
- ldm_DiTSmall_FF_ns2D.ckpt: same as previous, just smaller DiT size.
- ldm_DiTLarge_FF_ns2D.ckpt: same as previous, just larger DiT size.
- ldm_DiT_text_ns2D.ckpt: LDM trained to sample a smoke buoyancy solution conditioned on a text prompt
- ldm_DiTSmall_text_ns2D.ckpt: same as previous, just smaller DiT size.
- ldm_DiTLarge_text_ns2D.ckpt: same as previous, just larger DiT size.
For more information about the relevant training parameters, see the configs directory.
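To confirm that a downloaded checkpoint matches the intended config before training or sampling, it can be opened directly; a minimal sketch, assuming the .ckpt files are standard PyTorch Lightning checkpoints:

```python
import torch

# Load on CPU so no GPU is needed just to inspect the file.
ckpt = torch.load("ae_cylinder.ckpt", map_location="cpu")

# Lightning checkpoints typically expose 'state_dict', 'hyper_parameters', etc.
print(list(ckpt.keys()))
print(ckpt.get("hyper_parameters"))
print(len(ckpt["state_dict"]), "parameter tensors")
```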
Workflow for training a model:
- Set up environment
- Download a dataset
- Make a log directory
- Set up wandb
- Set paths to the dataset, normalization stats, and logging directory (a sketch for checking these paths follows this list)
- For LDM training: download a pretrained AE or train an AE
- Recommended training hardware:
- LDM_small: 1xA100
- LDM_medium: 4xA100
- LDM_large: 4xA100 80GB
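A quick way to verify a config before launching a run is to load it and check every path-like value, which catches most "file not found" failures early. This is a minimal sketch: it assumes the configs are plain YAML and simply treats any string containing a slash or a known extension as a path; the exact key names vary per config.

```python
import os
import yaml  # PyYAML is pulled in by the lightning install

with open("configs/ns2D/ldm/text/ldm_DiTSmall_text.yaml") as f:
    cfg = yaml.safe_load(f)

def check_paths(node, prefix=""):
    """Recursively flag any path-like string in the config that does not exist."""
    if isinstance(node, dict):
        for key, value in node.items():
            check_paths(value, f"{prefix}{key}.")
    elif isinstance(node, str) and ("/" in node or node.endswith((".h5", ".ckpt"))):
        status = "ok" if os.path.exists(node) else "MISSING"
        print(f"{prefix[:-1]}: {node} [{status}]")

check_paths(cfg)
```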
Workflow for inference:
- Download a pretrained LDM and AE model
- If doing text re-solving or FLOPs profiling: set up PhiFlow/DeepSpeed
- Set paths to pretrained model and logging directory
- Approximate memory requirements for inference w/ batch size = 1:
- LDM_small: 6.7 GB
- LDM_medium: 12.5 GB
- LDM_large: 35 GB
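To check whether the current GPU has enough free memory for the chosen model size, torch can report free and total device memory directly:

```python
import torch

free, total = torch.cuda.mem_get_info()  # bytes on the current CUDA device
print(f"free: {free / 1e9:.1f} GB / total: {total / 1e9:.1f} GB")
# e.g. LDM_medium needs roughly 12.5 GB free at batch size 1
```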
To train an autoencoder (supports KL regularization, GAN, LPIPS):
python train_AE.py --config=path/to/config
Example: train an autoencoder for the cylinder dataset without GAN or LPIPS
python train_AE.py --config=configs/cylinder/ae/ae_mesh.yaml
To train a latent diffusion model:
python train_ldm.py --config=path/to/config
Example: train a small-size LDM for the NS2D dataset with text conditioning
python train_ldm.py --config=configs/ns2D/ldm/text/ldm_DiTSmall_text.yaml
To train a baseline model for the cylinder flow problem:
python train_{gino/gnn/oformer}.py --config=path/to/config
To train a baseline model for the smoke buoyancy (ns2D) problem:
python train_{ns2D/acdm}.py --config=path/to/config
Note that the FNO, U-Net, and ResNet models all use the same script (train_ns2D.py).
Example: train an FNO baseline on the NS2D dataset
python train_ns2D.py --config=configs/ns2D/baselines/fno.yaml
To generate reconstructed samples on the validation set and evaluate a mean reconstruction loss on cuda:0:
python validate_AE.py --config=path/to/config --device="cuda:0"
For baselines (not including ACDM), to generate predicted samples on the validation set and evaluate a mean prediction loss:
python validate_{cylinder/ns2D}.py --config=path/to/config
For LDM and ACDM models, to conditionally sample from the validation set and evaluate a mean prediction loss:
python validate_ldm.py --config=path/to/config
Example: sample a pretrained medium-size LDM conditioned on text from the cylinder validation set.
python validate_ldm.py --config=configs/cylinder/ldm/text/ldm_DiT_text.yaml
Configs passed to validation scripts can also be used to generate a corresponding FLOPs profile:
python profile_flops.py --config=path/to/config
Example: profile FLOPs of the large LDM on NS2D with text conditioning:
python profile_flops.py --config=configs/ns2D/ldm/text/ldm_DiTLarge_text.yaml
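Independent of profile_flops.py, the FLOPs profiler bundled with the optional deepspeed install can also be called directly on any PyTorch module; the toy model and input shape below are placeholders for the loaded LDM or baseline:

```python
import torch
from deepspeed.profiling.flops_profiler import get_model_profile

# Placeholder module; substitute the model loaded from a checkpoint.
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.GELU())

flops, macs, params = get_model_profile(
    model=model,
    input_shape=(1, 128),  # a single dummy input is enough for a profile
    print_profile=True,
    as_string=True,
)
print(flops, macs, params)
```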
Scripts for captioning PDE simulations are in the text directory, along with some details.
Some parts of the code rely on Open3D, which requires torch<=2.0.1. If your installation is not compatible, Open3D can be disabled by setting use_open3d=False in all configs, and the codebase will fall back to a native PyTorch implementation that is slower and requires more memory.
Additionally, there are reports of FFT failing for pytorch-cuda<=11.7 (issue); only the FNO and GINO baselines make use of FFT.
Lastly, the smoke buoyancy problem relies on the torchdata package and its datapipes, which will be deprecated in the future. This may also cause compatibility issues with newer versions of torch (>=2.0.1), specifically:
File "/home/anaconda3/envs/env-name/lib/python3.11/site-packages/torchdata/datapipes/iter/util/cacheholder.py", line 24, in <module>
from torch.utils._import_utils import dill_available
ModuleNotFoundError: No module named 'torch.utils._import_utils'
A workaround is to create the missing module with a stub that always returns False:
cd /path-to-conda-env/lib/python3.11/site-packages/torch/utils
echo "def dill_available(): return False" > _import_utils.py
For those leveraging multiprocessing on a SLURM cluster, there are some additional considerations:
- If you plan on training the model with text capabilities, it is recommended to manually download the pretrained LLM weights (RoBERTa) and load them locally, as downloading weights on the fly may cause the script to hang. RoBERTa weights can be found here.
from transformers import AutoTokenizer, RobertaModel

# On the fly. Might cause the script to hang on nodes without internet access.
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
model = RobertaModel.from_pretrained("FacebookAI/roberta-base")

# Loading weights locally from cache_path (a directory with pre-downloaded weights). Safer option.
tokenizer = AutoTokenizer.from_pretrained(cache_path, local_files_only=True)
model = RobertaModel.from_pretrained(cache_path, local_files_only=True)
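To populate cache_path ahead of time (e.g. on a login node with internet access), the weights can be downloaded once and saved locally; the directory below is a placeholder:

```python
from transformers import AutoTokenizer, RobertaModel

cache_path = "/path/to/roberta-base"  # placeholder local directory

# Run once with internet access; compute nodes then load with local_files_only=True.
AutoTokenizer.from_pretrained("FacebookAI/roberta-base").save_pretrained(cache_path)
RobertaModel.from_pretrained("FacebookAI/roberta-base").save_pretrained(cache_path)
```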
- You may need to limit the number of train/val batches per epoch if using datapipes. In some DDP cases, having incomplete batches can cause GPUs to hang. Issue
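Both limits map to standard Lightning Trainer arguments; how they are surfaced in this repo's configs may differ, but set directly they look like:

```python
from lightning import Trainer

# Cap batches per epoch to avoid an incomplete final datapipe batch hanging DDP ranks.
trainer = Trainer(
    limit_train_batches=1000,
    limit_val_batches=100,
)
```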
Due to the size of the 3D turbulence datasets (9 TB) and models, they are not hosted on Hugging Face but will happily be provided upon request. The 3D turbulence datasets are from From Zero to Turbulence: Generative Modeling for 3D Flow Simulation and can also be downloaded by following that repo. 3D turbulence visualizations also require vape4d for rendering. Feel free to reach out for assistance in running the code or training a model.