diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..cf9dc24 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..91260c2 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,36 @@ +# Contributing to JMP +We want to make contributing to this project as easy and transparent as +possible. + + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://bugbounty.meta.com/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## Coding Style +* 2 spaces for indentation rather than tabs +* 80 character line length + +## License +By contributing to JMP, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..fdd1be6 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,400 @@ + +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..096265b --- /dev/null +++ b/README.md @@ -0,0 +1,297 @@ + +# Joint Multi-domain Pre-Training (JMP) + +Please see our [website](https://nima.sh/jmp) for more information about this work. + +This repository contains the code for the paper ["From Molecules to Materials: Pre-training Large Generalizable Models for Atomic Property Prediction"](https://openreview.net/forum?id=PfPnugdxup). + + + +This repository is deprecated and is no longer actively maintained. We are currently working on integrating the functionality of this repository into the official [Open Catalyst Project repository](https://github.com/Open-Catalyst-Project/ocp/). +If you have any questions or concerns with respect to this repository, please feel free to create a github issue or reach out to us via email. Please send an email to [Nima Shoghi](mailto:ns@nima.sh) and CC [Brandon Wood](mailto:bmwood@meta.com). + + + + +## Table of Contents +- [Overview](#overview) +- [Results](#results) + - [Small Molecule Datasets: QM9 and rmD17](#small-molecule-datasets-qm9-and-rmd17) + - [Materials Datasets: MatBench and QMOF](#materials-datasets-matbench-and-qmof) + - [Large Molecule Datasets: MD22 and SPICE](#large-molecule-datasets-md22-and-spice) +- [Installation](#installation) +- [Datasets](#datasets) + - [Pre-training Datasets](#pre-training-datasets) + - [ANI-1x Dataset](#ani-1x-dataset) + - [Transition-1x Dataset](#transition-1x-dataset) + - [Fine-tuning Datasets](#fine-tuning-datasets) +- [Pre-trained Checkpoints](#pre-trained-checkpoints) +- [Training Models](#training-models) +- [License](#license) +- [Citation](#citation) + + +## Overview + +![Main Figure](images/README/main_figure.png) + +In this work, we introduce Joint Multi-domain Pre-training (JMP), a supervised pre-training strategy that simultaneously trains on multiple datasets from different chemical domains, treating each dataset as a unique pre-training task within a multi-task framework. Our combined training dataset consists of ~120M systems from OC20, OC22, ANI-1x, and Transition-1x. + +The key contributions of this work are: + +1. We demonstrate JMP's powerful generalization ability by evaluating its fine-tuning performance across a diverse benchmark suite spanning small molecules, large molecules, and materials. JMP consistently outperforms training from scratch and sets or matches the state-of-the-art on 34 out of 40 fine-tuning benchmarks. + +2. We show that JMP enables efficient scaling to larger models that would normally overfit if trained from scratch on small datasets. Pre-training acts as a strong regularizer, allowing us to train a model with 235M parameters that sets new state-of-the-art performance on multiple low-data benchmarks. + +3. We conduct a detailed analysis of JMP's computational requirements. While expensive upfront, we show JMP's pre-training cost is recovered by enabling over 12x faster fine-tuning compared to training from scratch. + +By pre-training large models on diverse chemical data, we believe JMP represents an important step towards the goal of a universal ML potential for chemistry. The continued growth of available data and compute power will only improve JMP's ability to learn transferable atomic representations. + +## Results + +JMP demonstrates an average improvement of 59% over training from scratch, and matches or sets state-of-the-art on 34 out of 40 tasks. Our work highlights the potential of pre-training strategies that utilize diverse data to advance property prediction across chemical domains, especially for low-data tasks. + +### Small Molecule Datasets: QM9 and rmD17 +![QM9](images/README/qm9.png) + +![rMD17](images/README/rmd17.png) + +### Materials Datasets: MatBench and QMOF +![Materials](images/README/materials.png) + +### Large Molecule Datasets: MD22 and SPICE +![Large Molecules](images/README/large_molecules.png) + +## Installation + +First, clone the repository and navigate to the root directory: + +```bash +git clone https://github.com/facebookresearch/JMP.git +cd JMP +``` + +Then, set up the conda environment, as provided in the `environment.yml` file. To do so, run the following command (NOTE: replace `conda` with `mamba` if you have it installed): + +```bash +conda env create -f environment.yml -n jmp +``` + +If the above command fails, you can create the environment manually yourself: + +```bash +# Create the environment +conda create -n jmp python=3.11 +conda activate jmp + +# Install PyTorch +conda install -y -c pytorch -c nvidia pytorch torchvision torchaudio pytorch-cuda=12.1 + +# Install PyG, PyTorch Scatter, PyTorch Sparse. +conda install -c pyg pyg pytorch-sparse pytorch-cluster + +# Install other conda dependencies +conda install -y \ + -c conda-forge \ + numpy matplotlib seaborn sympy pandas numba scikit-learn plotly nbformat ipykernel ipywidgets tqdm pyyaml networkx \ + pytorch-lightning torchmetrics lightning \ + einops wandb \ + cloudpickle \ + "pydantic>2" \ + frozendict wrapt varname typing-extensions lovely-tensors lovely-numpy requests pytest nbval + +# Install pip dependencies +pip install lmdb + +# Install dependencies for materials datasets +pip install ase + +# Install dependencies for large molecule datasets +conda install h5py + +# Install MatBench dependencies +pip install matbench + +# Install dependencies for PDBBind +pip install biopython rdkit + +# Install dependencies for pre-processing ANI1x/Transition1x +pip install multiprocess +``` + +Then, activate the environment: + +```bash +conda activate jmp +``` + +Finally, install the current package as follows: + +```bash +pip install -e . +``` + +The code is now ready to be used. See the configuration files in the `configs` directory for examples on how to run the code. + +## Datasets + +The data used for pre-training and fine-tuning is not included in this repository due to size constraints. However, instructions for downloading and preprocessing the OC20, OC22, ANI-1x, Transition-1x, QM9, rMD17, MatBench, QMOF, SPICE, and MD22 datasets are provided below. + +### Pre-training Datasets + +- For OC20 and OC22, please refer to the [Open Catalyst Project dataset instructions](https://github.com/Open-Catalyst-Project/ocp/blob/main/DATASET.md). +- For ANI-1x, refer to the "ANI-1x Dataset" section below. +- For Transition-1x, refer to the "Transition-1x Dataset" section below. + +#### ANI-1x Dataset + +To download the ANI-1x dataset and convert it to a format that can be used by our codebase, please follow these steps: + +1. Create a directory to store the ANI-1x dataset and navigate to it: + +```bash +mkdir -p /path/to/datasets/ani1x +cd /path/to/datasets/ani1x +``` + +2. Download the ANI-1x dataset (in `.h5` format) from the [official source](https://springernature.figshare.com/ndownloader/files/18112775): + +```bash +wget https://springernature.figshare.com/ndownloader/files/18112775 -O ani1x-release.h5 +``` + +3. Compute the train, validation, and test splits: + +```bash +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_splits --input_file ani1x-release.h5 --train_keys_output train_keys.pkl --val_keys_output val_keys.pkl --test_keys_output test_keys.pkl +``` + +4. Convert the h5 file to a set of `.traj` files (set `--num_workers` to the number of CPU cores you have available): + +```bash +mkdir -p traj +mkdir -p traj/train traj/val traj/test +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_write_traj --ani1x_h5 ani1x-release.h5 --split_keys train_keys.pkl --split train --traj_dir traj/train --num_workers 32 +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_write_traj --ani1x_h5 ani1x-release.h5 --split_keys val_keys.pkl --split val --traj_dir traj/val --num_workers 32 +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_write_traj --ani1x_h5 ani1x-release.h5 --split_keys test_keys.pkl --split test --traj_dir traj/test --num_workers 32 +``` + +5. Convert the `.traj` files to `.lmdb` files: + +```bash +mkdir -p lmdb +mkdir -p lmdb/train lmdb/val lmdb/test +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_write_lmdbs --data_path traj/train --out_path lmdb/train --split train --num_workers 32 +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_write_lmdbs --data_path traj/val --out_path lmdb/val --split val --num_workers 32 +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_write_lmdbs --data_path traj/test --out_path lmdb/test --split test --num_workers 32 +``` + +6. Compute the linear reference energies: + +```bash +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_linear_ref linref --src lmdb/train --out_path linref.npz +``` + +7. Compute the mean/std of the energies: + +```bash +python -m jmp.datasets.scripts.ani1x_preprocess.ani1x_linear_ref compute_mean_std --src lmdb/train --linref_path linref.npz --out_path mean_std.pkl +``` + +#### Transition-1x Dataset + +To download the Transition-1x dataset and convert it to a format that can be used by our codebase, please follow these steps: + +1. Create a directory to store the Transition-1x dataset and navigate to it: + +```bash +mkdir -p /path/to/datasets/transition1x +cd /path/to/datasets/transition1x +``` + +2. Download the Transition-1x dataset (in `.h5` format) from the [official source](https://figshare.com/ndownloader/files/36035789): + +```bash +wget https://figshare.com/ndownloader/files/36035789 -O transition1x-release.h5 +``` + +3. Convert the h5 file to a set of `.traj` files (set `--num_workers` to the number of CPU cores you have available): + +```bash +mkdir -p traj +mkdir -p traj/train traj/val traj/test +python -m jmp.datasets.scripts.transition1x_preprocess.trans1x_write_traj --transition1x_h5 transition1x-release.h5 --split train --traj_dir traj/train --num_workers 32 +python -m jmp.datasets.scripts.transition1x_preprocess.trans1x_write_traj --transition1x_h5 transition1x-release.h5 --split val --traj_dir traj/val --num_workers 32 +python -m jmp.datasets.scripts.transition1x_preprocess.trans1x_write_traj --transition1x_h5 transition1x-release.h5 --split test --traj_dir traj/test --num_workers 32 +``` + +4. Convert the `.traj` files to `.lmdb` files: + +```bash +mkdir -p lmdb +mkdir -p lmdb/train lmdb/val lmdb/test +python -m jmp.datasets.scripts.transition1x_preprocess.trans1x_write_lmdbs --data_path traj/train --out_path lmdb/train --split train --num_workers 32 +python -m jmp.datasets.scripts.transition1x_preprocess.trans1x_write_lmdbs --data_path traj/val --out_path lmdb/val --split val --num_workers 32 +python -m jmp.datasets.scripts.transition1x_preprocess.trans1x_write_lmdbs --data_path traj/test --out_path lmdb/test --split test --num_workers 32 +``` + +5. Compute the linear reference energies: + +```bash +python -m jmp.datasets.scripts.transition1x_preprocess.trans1x_linear_ref linref --src lmdb/train --out_path linref.npz +``` + +6. Compute the mean/std of the energies: + +```bash +python -m jmp.datasets.scripts.transition1x_preprocess.trans1x_linear_ref compute_mean_std --src lmdb/train --linref_path linref.npz --out_path mean_std.pkl +``` + +### Fine-tuning Datasets +- For rMD17, please run the following command to download the dataset to your desired directory: `python -m jmp.datasets.finetune.rmd17 download --destination /path/to/datasets/rmd17/` +- For QM9, please run the following command to download the dataset to your desired directory: `python -m jmp.datasets.finetune.qm9 download --destination /path/to/datasets/qm9/` +- For MD22, please run the following command to download the dataset to your desired directory: `python -m jmp.datasets.finetune.md22 download --destination /path/to/datasets/md22/` +- For SPICE, please run the following command to download the dataset to your desired directory: `python -m jmp.datasets.finetune.spice download --destination /path/to/datasets/spice/` +- For MatBench, please run the following command to download the dataset to your desired directory: `python -m jmp.datasets.finetune.mat_bench download --destination /path/to/datasets/matbench/` +- For QMOF, please run the following command to download the dataset to your desired directory: `python -m jmp.datasets.finetune.qmof download --destination /path/to/datasets/qmof/` + +## Pre-trained Checkpoints + +The pre-trained checkpoints are available for download from the following links: + +- [JMP-S](https://jmp-iclr-datasets.s3.amazonaws.com/jmp-s.pt) +- [JMP-L](https://jmp-iclr-datasets.s3.amazonaws.com/jmp-l.pt) + +## Training Models + +Our codebase is designed to be Python-based. To train a model, you create a configuration file that specifies the model, dataset, and training parameters. These configuration objects are fully type-checked and validated using Pydantic. + +Once you have created a configuration object, you can use the `jmp.lightning.Runner` class to wrap the training loop. See `configs/jmp_l_finetune.ipynb` for an example of how to finetune a model. + +## License + +The majority of JMP is CC-BY-NC licensed, as found in the `LICENSE` file. However, portions of the project are available under separate license terms: + +- The ASE library is licensed under the GNU Lesser General Public License v2.1. +- PyTorch Lightning and TorchMetrics are licensed under the Apache License 2.0. +- DeepChem is licensed under the MIT License. +- RDKit is licensed under the BSD 3-Clause License. +- Biopython is licensed under the Biopython License Agreement. +- Pydantic is licensed under the MIT License. +- MatBench is licensed under the MIT License. +- Submitit is licensed under the MIT License. +- Model implementations are based on the Open Catalyst Project and are licensed under the MIT License. +- EMA implementation is based on the NeMo's implementation and is licensed under the Apache License 2.0. + +## Citation + +If you use this code in your research, please cite the following paper: + +``` +@article{shoghi2023molecules, + title={From molecules to materials: Pre-training large generalizable models for atomic property prediction}, + author={Shoghi, Nima and Kolluru, Adeesh and Kitchin, John R and Ulissi, Zachary W and Zitnick, C Lawrence and Wood, Brandon M}, + journal={arXiv preprint arXiv:2310.16802}, + year={2023} +} +``` diff --git a/config/all-jmp-l-configs.py b/config/all-jmp-l-configs.py new file mode 100644 index 0000000..4700143 --- /dev/null +++ b/config/all-jmp-l-configs.py @@ -0,0 +1,602 @@ +# %% +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from collections.abc import Iterable +from pathlib import Path +from typing import Literal + +from jmp.lightning import Runner, Trainer +from jmp.models.gemnet.config import BackboneConfig +from jmp.modules.ema import EMA +from jmp.modules.transforms.normalize import NormalizationConfig as NC +from jmp.tasks.config import AdamWConfig +from jmp.tasks.finetune import dataset_config as DC +from jmp.tasks.finetune import matbench, md22, qm9, qmof, rmd17, spice +from jmp.tasks.finetune.base import ( + CheckpointBestConfig, + EarlyStoppingConfig, + FinetuneConfigBase, + FinetuneModelBase, + MulticlassClassificationTargetConfig, + PrimaryMetricConfig, + RLPConfig, + WarmupCosRLPConfig, +) +from jmp.utils.param_specific_util import make_parameter_specific_optimizer_config + +FinetuneConfigBase.set_seed(42) + +BASE_DATASET_PATH = Path("/mnt/shared/datasets") +SCALE_FILE_PATH = Path("/path/to/gemnet/scale_files/large.pt") +PRETRAINED_CKPT_PATH = Path("/path/to/pretrained/checkpoint.ckpt") + + +def config_( + config: FinetuneConfigBase, + *, + batch: int, + scalar: bool = False, +): + # Large model + config.backbone = BackboneConfig.large() + config.embedding.embedding_size = config.backbone.emb_size_atom + # config.ln = False + # config.backbone.ln = False + # config.backbone.replace_scale_factors_with_ln = False + config.backbone.scale_basis = False + + # Misc + config.meta["ckpt_path"] = str(PRETRAINED_CKPT_PATH.absolute()) + config.backbone.scale_file = str(SCALE_FILE_PATH.absolute()) + config.meta["ema_backbone"] = True + config.trainer.precision = "16-mixed" + + # Stopping criteria + if isinstance(config, rmd17.RMD17Config): + config.trainer.max_epochs = 100000 + config.trainer.max_time = "07:00:00:00" + config.early_stopping = EarlyStoppingConfig( + patience=1000, + min_delta=1.0e-8, + min_lr=1.0e-10, + ) + else: + config.trainer.max_epochs = 500 + config.trainer.max_time = "07:00:00:00" + config.early_stopping = EarlyStoppingConfig( + patience=50, + min_delta=1.0e-8, + min_lr=1.0e-8, + ) + # Checkpointing + config.ckpt_best = CheckpointBestConfig() + + # Training speedup by disabling some features + config.trainer.optimizer.log_grad_norm = True + config.trainer.optimizer.log_grad_norm_per_param = False + config.trainer.optimizer.log_param_norm = False + config.trainer.optimizer.log_param_norm_per_param = False + config.trainer.supports_parameter_hooks = False + config.trainer.supports_skip_batch_exception = False + config.trainer.logging.wandb.log_model = False + + # Optimizer + config.optimizer = AdamWConfig( + lr=8.0e-5, + amsgrad=False, + betas=(0.9, 0.95), + eps=1.0e-8, + weight_decay=0.1, + ) + config.trainer.gradient_clip_val = 1.0 + config.trainer.gradient_clip_algorithm = "value" + + # LR Scheduler + if isinstance(config, rmd17.RMD17Config): + config.lr_scheduler = WarmupCosRLPConfig( + warmup_epochs=5, + warmup_start_lr_factor=1.0e-1, + should_restart=False, + max_epochs=32, + min_lr_factor=0.1, + rlp=RLPConfig(mode="min", patience=10, factor=0.8), + ) + else: + config.lr_scheduler = WarmupCosRLPConfig( + warmup_epochs=5, + warmup_start_lr_factor=1.0e-1, + should_restart=False, + max_epochs=32, + min_lr_factor=0.1, + rlp=RLPConfig(mode="min", patience=3, factor=0.8), + ) + + config.parameter_specific_optimizers = make_parameter_specific_optimizer_config( + config, + config.backbone.num_blocks, + { + "embedding": 0.3, + "blocks_0": 0.55, + "blocks_1": 0.40, + "blocks_2": 0.30, + "blocks_3": 0.40, + "blocks_4": 0.55, + "blocks_5": 0.625, + }, + ) + + # Passed args + config.project = "4_11_ft_lg_jmp_testing" + config.batch_size = batch + if scalar: + config.backbone.regress_forces = False + config.backbone.direct_forces = False + + +MD17_STATS: dict[rmd17.RMD17Molecule, dict[str, NC]] = { + "aspirin": { + "y": NC(mean=17617.379355234374, std=0.2673998440577667), + "force": NC(mean=0.0, std=1.2733363), + }, + "azobenzene": { + "y": NC(mean=15553.118351233397, std=0.2866098335926971), + "force": NC(mean=0.0, std=1.2940075), + }, + "benzene": { + "y": NC(mean=6306.374855859375, std=0.10482645661015047), + "force": NC(mean=0.0, std=0.90774584), + }, + "ethanol": { + "y": NC(mean=4209.534573266602, std=0.18616576961275716), + "force": NC(mean=0.0, std=1.1929188), + }, + "malonaldehyde": { + "y": NC(mean=7254.903633896484, std=0.1812291921138577), + "force": NC(mean=0.0, std=1.302443), + }, + "naphthalene": { + "y": NC(mean=10478.192319667969, std=0.24922674853668708), + "force": NC(mean=0.0, std=1.3102233), + }, + "paracetamol": { + "y": NC(mean=13998.780924130859, std=0.26963984094801224), + "force": NC(mean=0.0, std=1.2707518), + }, + "salicylic": { + "y": NC(mean=13472.110348867187, std=0.2437920552529055), + "force": NC(mean=0.0, std=1.3030343), + }, + "toluene": { + "y": NC(mean=7373.347077485351, std=0.22534282741069667), + "force": NC(mean=0.0, std=1.246547), + }, + "uracil": { + "y": NC(mean=11266.351949697266, std=0.2227113171300836), + "force": NC(mean=0.0, std=1.3692871), + }, +} + + +def md17_configs(is_grad: bool = True): + for molecule in MD17_STATS: + config = rmd17.RMD17Config.draft() + config.molecule = molecule + + config.name = f"md17_{molecule}" + config.model_type = "forces" + if is_grad: + config.gradient_forces = True + config.trainer.inference_mode = False + config.name += "_grad" + + config_(config, batch=1, scalar=is_grad) + config.trainer.precision = "32-true" + + config.normalization = MD17_STATS[molecule] + config.primary_metric = PrimaryMetricConfig(name="force_mae", mode="min") + + # Dataset + base_path = BASE_DATASET_PATH / "rmd17" + config.train_dataset = DC.rmd17_config(molecule, base_path, "train") + config.val_dataset = DC.rmd17_config(molecule, base_path, "val") + config.test_dataset = DC.rmd17_config(molecule, base_path, "test") + + yield config, rmd17.RMD17Model + + +MD22_NORM: dict[md22.MD22Molecule, dict[str, NC]] = { + "Ac-Ala3-NHMe": { + "y": NC(mean=26913.953, std=0.35547638), + "force": NC(mean=1.4777572e-11, std=1.1291506), + }, + "DHA": { + "y": NC(mean=27383.035, std=0.41342595), + "force": NC(mean=5.5828797e-10, std=1.1258113), + }, + "stachyose": { + "y": NC(mean=68463.59, std=0.5940788), + "force": NC(mean=4.9331733e-10, std=1.1104717), + }, + "AT-AT": { + "y": NC(mean=50080.08, std=0.47309175), + "force": NC(mean=1.3477714e-09, std=1.2109985), + }, + "AT-AT-CG-CG": { + "y": NC(mean=101034.23, std=0.680055), + "force": NC(mean=3.476294e-10, std=1.2021886), + }, + "buckyball-catcher": { + "y": NC(mean=124776.7, std=0.64662045), + "force": NC(mean=6.8671324e-10, std=1.0899031), + }, + "double-walled_nanotube": { + "y": NC(mean=338224.16, std=3.3810701), + "force": NC(mean=7.239396e-11, std=1.0137014), + }, +} + +MD22_MOLECULES: list[tuple[md22.MD22Molecule, int, bool, bool]] = [ + ("Ac-Ala3-NHMe", 2, False, True), + ("DHA", 1, True, False), + ("stachyose", 1, True, True), + ("stachyose", 1, True, True), + ("AT-AT", 1, True, True), + ("AT-AT-CG-CG", 1, True, True), + ("buckyball-catcher", 1, True, True), + ("double-walled_nanotube", 1, False, True), +] + + +def md22_configs(): + for molecule, bsz, is_grad, amp in MD22_MOLECULES: + config = md22.MD22Config.draft() + config.molecule = molecule + + config.name = f"md22_{molecule}" + config.model_type = "forces" + if is_grad: + config.gradient_forces = True + config.trainer.inference_mode = False + config.name += "_grad" + else: + config.name += "_direct" + + config_(config, batch=bsz, scalar=is_grad) + + config.normalization = MD22_NORM[molecule] + config.primary_metric = PrimaryMetricConfig(name="force_mae", mode="min") + + base_path = BASE_DATASET_PATH / "md22" + config.train_dataset = DC.md22_config(molecule, base_path, "train") + config.val_dataset = DC.md22_config(molecule, base_path, "val") + config.test_dataset = DC.md22_config(molecule, base_path, "test") + + if amp: + config.trainer.precision = "16-mixed" + else: + config.trainer.precision = "32-true" + + yield config, md22.MD22Model + + +QM9_NORMALIZATION: dict[str, NC] = { + "mu": NC(mean=2.674587, std=1.5054824), + "alpha": NC(mean=75.31013, std=8.164021), + "eps_HMO": NC(mean=-6.5347567, std=0.59702325), + "eps_LUMO": NC(mean=0.323833, std=1.273586), + "delta_eps": NC(mean=6.8585854, std=1.283122), + "R_2_Abs": NC(mean=1189.6819, std=280.0421), + "ZPVE": NC(mean=-0.00052343315, std=0.04904531), + "U_0": NC(mean=0.0028667436, std=1.0965848), + "U": NC(mean=0.0028711546, std=1.0941933), + "H": NC(mean=0.0029801112, std=1.0942822), + "G": NC(mean=0.000976671, std=1.101572), + "c_v": NC(mean=-0.005799451, std=2.2179737), +} +QM9_TARGETS = [ + "mu", + "alpha", + "eps_HMO", + "eps_LUMO", + "delta_eps", + "R_2_Abs", + "ZPVE", + "U_0", + "U", + "H", + "G", + "c_v", +] +QM9_REDUCTION: dict[str, Literal["sum", "mean", "max"]] = { + "mu": "sum", + "alpha": "sum", + "eps_HMO": "sum", + "eps_LUMO": "sum", + "delta_eps": "sum", + "R_2_Abs": "sum", + "ZPVE": "sum", + "U_0": "sum", + "U": "sum", + "H": "sum", + "G": "sum", + "c_v": "sum", +} + + +def qm9_configs(): + for target in QM9_TARGETS: + config = qm9.QM9Config.draft() + config.graph_scalar_targets = [target] + config.name = f"qm9_{target}" + config.normalization = QM9_NORMALIZATION + # config.use_scalar_head_for_all_targets = True + config.graph_scalar_reduction = QM9_REDUCTION + config_(config, batch=48, scalar=True) + config.primary_metric = PrimaryMetricConfig(name=f"{target}_mae", mode="min") + + if target == "R_2_Abs": + config.output_head = qm9.SpatialExtentConfig() + + # Also, we don't use any normalization for this target + config.normalization = {} + else: + config.output_head = qm9.DefaultOutputHeadConfig() + + base_path = BASE_DATASET_PATH / "qm9" + config.train_dataset = DC.qm9_config(base_path, "train") + config.val_dataset = DC.qm9_config(base_path, "val") + config.test_dataset = DC.qm9_config(base_path, "test") + + yield config, qm9.QM9Model + + +QMOF_NORMALIZATION: dict[str, NC] = {"y": NC(mean=2.1866251527, std=1.175752521125648)} + + +def qmof_configs(): + config = qmof.QMOFConfig.draft() + + config.name = "qmof" + config.graph_scalar_reduction_default = "mean" + config.normalization = QMOF_NORMALIZATION + config_(config, batch=4, scalar=True) + + config.primary_metric = PrimaryMetricConfig(name="y_mae", mode="min") + + base_path = BASE_DATASET_PATH / "qmof" + config.train_dataset = DC.qmof_config(base_path, "train") + config.val_dataset = DC.qmof_config(base_path, "val") + config.test_dataset = DC.qmof_config(base_path, "test") + + yield config, qmof.QMOFModel + + +SPICE_DATASETS: list[spice.SPICEDataset] = [ + "solvated_amino_acids", + "dipeptides", +] +SPICE_NORMALIZATION: dict[str, dict[str, NC]] = { + "dipeptides": { + "y": NC(mean=-31213.615, std=4636.815), + "force": NC(mean=3.3810358e-07, std=0.5386545), + }, + "solvated_amino_acids": { + "y": NC(mean=-60673.68, std=3310.6692), + "force": NC(mean=2.7950014e-07, std=0.81945145), + }, +} + + +def spice_configs(is_grad: bool = True): + for dataset in SPICE_DATASETS: + config = spice.SPICEConfig.draft() + + config.dataset = dataset + config.name = f"spice_{dataset}" + config.model_type = "forces" + if is_grad: + config.gradient_forces = True + config.trainer.inference_mode = False + config.name += "_grad" + + config_(config, batch=2, scalar=is_grad) + config.primary_metric = PrimaryMetricConfig(name="force_mae", mode="min") + + config.normalization = SPICE_NORMALIZATION[dataset] + + base_path = BASE_DATASET_PATH / "spice" + config.train_dataset = DC.spice_config(dataset, base_path, "train") + config.val_dataset = DC.spice_config(dataset, base_path, "val") + config.test_dataset = DC.spice_config(dataset, base_path, "test") + + yield config, spice.SPICEModel + + +MATBENCH_DATASETS: list[tuple[matbench.MatbenchDataset, int, bool]] = [ + ("mp_is_metal", 3, False), + ("jdft2d", 3, False), + ("phonons", 8, True), + ("dielectric", 8, True), + ("log_gvrh", 8, True), + ("log_kvrh", 8, True), + ("perovskites", 8, True), + ("mp_gap", 2, False), + ("mp_e_form", 6, True), +] +MATBENCH_NORMALIZATION: dict[str, dict[str, NC]] = { + "jdft2d_fold0": {"y": NC(mean=110.63706001904778, std=132.02502987887982)}, + "jdft2d_fold1": {"y": NC(mean=100.05996525195053, std=114.26362221432791)}, + "jdft2d_fold2": {"y": NC(mean=101.59535193788061, std=112.45760038504558)}, + "jdft2d_fold3": {"y": NC(mean=99.43551549230911, std=109.9220303290942)}, + "jdft2d_fold4": {"y": NC(mean=95.50851385805468, std=76.27587565670332)}, + "phonons_fold0": {"y": NC(mean=602.9007780432183, std=471.03858838413055)}, + "phonons_fold1": {"y": NC(mean=613.7996473907606, std=486.75099875453213)}, + "phonons_fold2": {"y": NC(mean=619.3868976573087, std=495.2975486965762)}, + "phonons_fold3": {"y": NC(mean=609.7402387661577, std=462.3438660855412)}, + "phonons_fold4": {"y": NC(mean=595.4547502676089, std=476.8567310885976)}, + "dielectric_fold0": {"y": NC(mean=2.417849270334958, std=2.208662738016193)}, + "dielectric_fold1": {"y": NC(mean=2.3716402963883074, std=2.1271523121706912)}, + "dielectric_fold2": {"y": NC(mean=2.354418196731436, std=1.5712251872961516)}, + "dielectric_fold3": {"y": NC(mean=2.392308273978868, std=2.0724149898647544)}, + "dielectric_fold4": {"y": NC(mean=2.3891527750974495, std=2.011348533899877)}, + "log_gvrh_fold0": {"y": NC(mean=1.5557434474198688, std=0.37307197984408746)}, + "log_gvrh_fold1": {"y": NC(mean=1.5584101768747889, std=0.36743473539736493)}, + "log_gvrh_fold2": {"y": NC(mean=1.55746252819908, std=0.36800038945046654)}, + "log_gvrh_fold3": {"y": NC(mean=1.5543022349873286, std=0.3684552493569905)}, + "log_gvrh_fold4": {"y": NC(mean=1.5595705795473838, std=0.37039750391284176)}, + "log_kvrh_fold0": {"y": NC(mean=1.880001033036957, std=0.36820395518377785)}, + "log_kvrh_fold1": {"y": NC(mean=1.883820392919235, std=0.3679308395031994)}, + "log_kvrh_fold2": {"y": NC(mean=1.883778380784775, std=0.3724392829717956)}, + "log_kvrh_fold3": {"y": NC(mean=1.8828457515367547, std=0.3731179944882516)}, + "log_kvrh_fold4": {"y": NC(mean=1.8862681006404232, std=0.3671596024523317)}, + "perovskites_fold0": {"y": NC(mean=1.4726657310327749, std=0.7384309800882398)}, + "perovskites_fold1": {"y": NC(mean=1.4690728968876414, std=0.736635027626099)}, + "perovskites_fold2": {"y": NC(mean=1.4702980269132337, std=0.7456716470700677)}, + "perovskites_fold3": {"y": NC(mean=1.46773815420175, std=0.7431740904365189)}, + "perovskites_fold4": {"y": NC(mean=1.478002311375268, std=0.7435117654840315)}, + "mp_gap_fold0": {"y": NC(mean=1.236432240252091, std=1.6096424425437108)}, + "mp_gap_fold1": {"y": NC(mean=1.2345678083402052, std=1.6044708412420103)}, + "mp_gap_fold2": {"y": NC(mean=1.2352391374131229, std=1.6058092380465256)}, + "mp_gap_fold3": {"y": NC(mean=1.230066812934386, std=1.6003749533498033)}, + "mp_gap_fold4": {"y": NC(mean=1.2350543114618917, std=1.6035590734723943)}, + "mp_e_form_fold0": {"y": NC(mean=-1.4371594843879998, std=1.1577096884761835)}, + "mp_e_form_fold1": {"y": NC(mean=-1.4372781184639032, std=1.1576872656463288)}, + "mp_e_form_fold2": {"y": NC(mean=-1.4353308741245294, std=1.1568986659292604)}, + "mp_e_form_fold3": {"y": NC(mean=-1.4337824626302396, std=1.1570679204976484)}, + "mp_e_form_fold4": {"y": NC(mean=-1.437067044514929, std=1.1567267481888575)}, +} + + +def matbench_configs(folds: Iterable[matbench.MatbenchFold]): + for dataset, bsz, amp in MATBENCH_DATASETS: + for fold in folds: + config = matbench.MatbenchConfig.draft() + config.dataset = dataset + match dataset: + case "phonons": + config.graph_scalar_reduction_default = "max" + case "mp_is_metal": + config.graph_scalar_targets = [] + config.graph_classification_targets = [] + config.graph_classification_targets.append( + MulticlassClassificationTargetConfig( + name="y", + num_classes=2, + class_weights=[1.0, 1.34219693], + dropout=0.5, + ) + ) + config.node_vector_targets = [] + case _: + config.graph_scalar_reduction_default = "mean" + + config.fold = fold + config.name = f"matbench_{dataset}_fold{fold}" + config.mp_e_form_dev = False + + config.normalization = MATBENCH_NORMALIZATION.get( + f"{dataset}_fold{fold}", {} + ) + config.conditional_max_neighbors = True + config_(config, batch=bsz, scalar=True) + + if dataset == "mp_is_metal": + config.primary_metric = PrimaryMetricConfig( + name="y_balanced_accuracy", mode="max" + ) + else: + config.primary_metric = PrimaryMetricConfig(name="y_mae", mode="min") + + if amp: + config.trainer.precision = "16-mixed" + else: + config.trainer.precision = "32-true" + + base_path = BASE_DATASET_PATH / "matbench" + config.train_dataset = DC.matbench_config(dataset, base_path, "train", fold) + config.val_dataset = DC.matbench_config(dataset, base_path, "val", fold) + config.test_dataset = DC.matbench_config(dataset, base_path, "test", fold) + + yield config, matbench.MatbenchModel + + +def all_runs(): + yield from matbench_configs([0, 1, 2, 3, 4]) + yield from md17_configs() + yield from md22_configs() + yield from qm9_configs() + yield from qmof_configs() + yield from spice_configs() + + +configs: list[tuple[FinetuneConfigBase, type[FinetuneModelBase]]] = [] +for base_config, model_cls in all_runs(): + config = copy.deepcopy(base_config) + config.id = FinetuneConfigBase.generate_id() + + config.trainer.logging.wandb.log_model = False + config = config.finalize() + configs.append((config, model_cls)) + +for config, _ in configs: + assert config.backbone.scale_file, f"Scale file not set for {config.name}" + +print("\n".join([c.name for c, _ in configs])) +print(len(configs)) + + +# %% +from jmp.lightning import Runner, Trainer +from jmp.modules.ema import EMA +from jmp.utils.finetune_state_dict import ( + filter_state_dict, + retreive_state_dict_for_finetuning, +) + + +def run(config: FinetuneConfigBase, model_cls: type[FinetuneModelBase]): + model = model_cls(config) + + if (ckpt_path := config.meta.get("ckpt_path")) is None: + raise ValueError("ckpt_path must be provided") + + state_dict = retreive_state_dict_for_finetuning( + ckpt_path, + load_emas=config.meta.get("ema_backbone", False), + ) + embedding = filter_state_dict(state_dict, "embedding.atom_embedding.") + backbone = filter_state_dict(state_dict, "backbone.") + + model.load_backbone_state_dict(backbone=backbone, embedding=embedding, strict=False) + + callbacks = [] + if (ema := config.meta.get("ema")) is not None: + ema = EMA(decay=ema) + callbacks.append(ema) + + trainer = Trainer(config, callbacks=callbacks) + trainer.fit(model) + + +# %% +# runner = Runner(run) +# runner.fast_dev_run(configs) + +# %% +runner = Runner(run) +jobs = runner.submit( + configs, + gpus=1, + nodes=1, + cpus_per_task=configs[0][0].num_workers + 1, + partition="ocp", + constraint="volta32gb", + snapshot=True, +) diff --git a/config/jmp_l_finetune.ipynb b/config/jmp_l_finetune.ipynb new file mode 100644 index 0000000..d6c83ff --- /dev/null +++ b/config/jmp_l_finetune.ipynb @@ -0,0 +1,335 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No normalization for SPS. Feature removed!\n", + "No normalization for AvgIpc. Feature removed!\n", + "Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'\n", + "Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'\n", + "Skipped loading modules with transformers dependency. No module named 'transformers'\n", + "cannot import name 'HuggingFaceModel' from 'deepchem.models.torch_models' (/opt/conda/envs/jmp/lib/python3.11/site-packages/deepchem/models/torch_models/__init__.py)\n", + "Skipped loading some Jax models, missing a dependency. No module named 'jax'\n", + "Skipped loading some PyTorch models, missing a dependency. No module named 'tensorflow'\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "id='mpo5imfg' trainer=TrainerConfig(optimizer=OptimizerConfig(log_grad_norm=True, gradient_clipping=GradientClippingConfig(value=1.0, algorithm='value')), supports_skip_batch_exception=False, supports_parameter_hooks=False, set_float32_matmul_precision='medium', precision='16-mixed', max_epochs=100000, max_time='07:00:00:00', inference_mode=False) meta={'ckpt_path': PosixPath('/mnt/shared/checkpoints/fm_gnoc_large_2_epoch.ckpt'), 'ema_backbone': True} train_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/train'), metadata_path=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/train/metadata.npz')) val_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/val'), metadata_path=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/val/metadata.npz')) test_dataset=FinetuneLmdbDatasetConfig(src=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/test'), metadata_path=PosixPath('/mnt/shared/datasets/rmd17/lmdb/aspirin/test/metadata.npz')) optimizer=AdamWConfig(lr=5e-06, weight_decay=0.1, betas=(0.9, 0.95)) lr_scheduler=WarmupCosRLPConfig(warmup_epochs=5, max_epochs=32, warmup_start_lr_factor=0.1, min_lr_factor=0.1, rlp=RLPConfig(patience=25, factor=0.8)) backbone=BackboneConfig(num_spherical=7, num_radial=128, num_blocks=6, emb_size_atom=256, emb_size_edge=1024, emb_size_trip_in=64, emb_size_trip_out=128, emb_size_quad_in=64, emb_size_quad_out=32, emb_size_aint_in=64, emb_size_aint_out=64, emb_size_rbf=32, emb_size_cbf=16, emb_size_sbf=64, num_before_skip=2, num_after_skip=2, num_concat=4, num_atom=3, num_output_afteratom=3, num_atom_emb_layers=2, regress_forces=False, sbf={'name': 'legendre_outer'}, quad_interaction=True, atom_edge_interaction=True, edge_atom_interaction=True, atom_interaction=True, qint_tags=[1, 2], absolute_rbf_cutoff=12.0, dropout=None, edge_dropout=None) batch_size=4 primary_metric=PrimaryMetricConfig(name='force_mae', mode='min') early_stopping=EarlyStoppingConfig(patience=1000, min_lr=1e-10) normalization={'y': NormalizationConfig(mean=-17617.379355234374, std=0.2673998440577667), 'force': NormalizationConfig(mean=0.0, std=1.2733363)} parameter_specific_optimizers=[ParamSpecificOptimizerConfig(paremeter_patterns=['embedding.*'], optimizer=AdamWConfig(lr=1.5e-06, weight_decay=0.1, betas=(0.9, 0.95)), lr_scheduler=WarmupCosRLPConfig(warmup_epochs=5, max_epochs=32, warmup_start_lr_factor=0.1, min_lr_factor=0.33333333333333337, rlp=RLPConfig(patience=3, factor=0.8))), ParamSpecificOptimizerConfig(paremeter_patterns=['backbone.int_blocks.0.*', 'backbone.out_blocks.1.*', 'backbone.out_blocks.0.*'], optimizer=AdamWConfig(lr=2.7500000000000004e-06, weight_decay=0.1, betas=(0.9, 0.95)), lr_scheduler=WarmupCosRLPConfig(warmup_epochs=5, max_epochs=32, warmup_start_lr_factor=0.1, min_lr_factor=0.18181818181818182, rlp=RLPConfig(patience=3, factor=0.8))), ParamSpecificOptimizerConfig(paremeter_patterns=['backbone.int_blocks.1.*', 'backbone.out_blocks.2.*'], optimizer=AdamWConfig(lr=2.0000000000000003e-06, weight_decay=0.1, betas=(0.9, 0.95)), lr_scheduler=WarmupCosRLPConfig(warmup_epochs=5, max_epochs=32, warmup_start_lr_factor=0.1, min_lr_factor=0.25, rlp=RLPConfig(patience=3, factor=0.8))), ParamSpecificOptimizerConfig(paremeter_patterns=['backbone.int_blocks.2.*', 'backbone.out_blocks.3.*'], optimizer=AdamWConfig(lr=1.5e-06, weight_decay=0.1, betas=(0.9, 0.95)), lr_scheduler=WarmupCosRLPConfig(warmup_epochs=5, max_epochs=32, warmup_start_lr_factor=0.1, min_lr_factor=0.33333333333333337, rlp=RLPConfig(patience=3, factor=0.8))), ParamSpecificOptimizerConfig(paremeter_patterns=['backbone.int_blocks.3.*', 'backbone.out_blocks.4.*'], optimizer=AdamWConfig(lr=2.0000000000000003e-06, weight_decay=0.1, betas=(0.9, 0.95)), lr_scheduler=WarmupCosRLPConfig(warmup_epochs=5, max_epochs=32, warmup_start_lr_factor=0.1, min_lr_factor=0.25, rlp=RLPConfig(patience=3, factor=0.8))), ParamSpecificOptimizerConfig(paremeter_patterns=['backbone.int_blocks.4.*', 'backbone.out_blocks.5.*'], optimizer=AdamWConfig(lr=2.7500000000000004e-06, weight_decay=0.1, betas=(0.9, 0.95)), lr_scheduler=WarmupCosRLPConfig(warmup_epochs=5, max_epochs=32, warmup_start_lr_factor=0.1, min_lr_factor=0.18181818181818182, rlp=RLPConfig(patience=3, factor=0.8))), ParamSpecificOptimizerConfig(paremeter_patterns=['backbone.int_blocks.5.*', 'backbone.out_blocks.6.*'], optimizer=AdamWConfig(lr=3.125e-06, weight_decay=0.1, betas=(0.9, 0.95)), lr_scheduler=WarmupCosRLPConfig(warmup_epochs=5, max_epochs=32, warmup_start_lr_factor=0.1, min_lr_factor=0.16, rlp=RLPConfig(patience=3, factor=0.8)))] gradient_forces=True model_type='forces' molecule='aspirin'\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspaces/repositories/fm/src/jmp/lightning/model/config.py:709: BaseConfig._rng is None. The generated IDs will not be reproducible. To fix this, call BaseConfig.set_seed(...) before generating any IDs.\n" + ] + } + ], + "source": [ + "\"\"\"\n", + "Copyright (c) Meta Platforms, Inc. and affiliates.\n", + "All rights reserved.\n", + "\n", + "This source code is licensed under the license found in the\n", + "LICENSE file in the root directory of this source tree.\n", + "\"\"\"\n", + "\n", + "from pathlib import Path\n", + "\n", + "from jmp.configs.finetune.jmp_l import jmp_l_ft_config_\n", + "from jmp.configs.finetune.rmd17 import jmp_l_rmd17_config_\n", + "from jmp.tasks.finetune.base import FinetuneConfigBase, FinetuneModelBase\n", + "from jmp.tasks.finetune.rmd17 import RMD17Config, RMD17Model\n", + "\n", + "ckpt_path = Path(\"/mnt/shared/checkpoints/fm_gnoc_large_2_epoch.ckpt\")\n", + "base_path = Path(\"/mnt/shared/datasets/rmd17/\")\n", + "\n", + "# We create a list of all configurations that we want to run.\n", + "configs: list[tuple[FinetuneConfigBase, type[FinetuneModelBase]]] = []\n", + "\n", + "config = RMD17Config.draft()\n", + "jmp_l_ft_config_(config, ckpt_path) # This loads the base JMP-L fine-tuning config\n", + "# This loads the rMD17-specific configuration\n", + "jmp_l_rmd17_config_(config, \"aspirin\", base_path)\n", + "config = config.finalize() # Actually construct the config object\n", + "print(config)\n", + "\n", + "configs.append((config, RMD17Model))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a6a4164e0219458c8b37bde34b2ae030", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fast dev run: 0%| | 0/1 [00:00]}.\n", + "Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n", + "Using 16bit Automatic Mixed Precision (AMP)\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/opt/conda/envs/jmp/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.\n", + "WARNING:jmp.lightning.trainer.logging:Logger DummyLogger does not support run_id, ignoring.\n", + "CRITICAL:jmp.lightning.trainer.trainer:LightningTrainer log directory: None.\n", + "/opt/conda/envs/jmp/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:126: `.fit(ckpt_path=None)` was called without a model. The last model of the previous `fit` call will be used. You can pass `fit(ckpt_path='best')` to use the best model or `fit(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.\n", + "/opt/conda/envs/jmp/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path=\"last\") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.\n", + "WARNING:jmp.lightning.model.modules.wandb:Could not find wandb logger or module to log\n", + "CRITICAL:jmp.lightning.model.base:Fast dev run detected, setting debug flag to True.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "CRITICAL:jmp.tasks.config:Optimizer: AdamW\n", + "Optimizer kwargs: {}\n", + "Base kwargs: {'lr': 5e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08}\n", + "Param groups: Param group 0:\n", + " Params: 1\n", + " Total param size: 30720\n", + " Other kwargs: {'lr': 1.5e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'name': 'embedding.*', 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Param group 1:\n", + " Params: 85\n", + " Total param size: 27354112\n", + " Other kwargs: {'lr': 2.7500000000000004e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'name': 'backbone.int_blocks.0.*,backbone.out_blocks.1.*,backbone.out_blocks.0.*', 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Param group 2:\n", + " Params: 71\n", + " Total param size: 26272768\n", + " Other kwargs: {'lr': 2.0000000000000003e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'name': 'backbone.int_blocks.1.*,backbone.out_blocks.2.*', 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Param group 3:\n", + " Params: 71\n", + " Total param size: 26272768\n", + " Other kwargs: {'lr': 1.5e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'name': 'backbone.int_blocks.2.*,backbone.out_blocks.3.*', 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Param group 4:\n", + " Params: 71\n", + " Total param size: 26272768\n", + " Other kwargs: {'lr': 2.0000000000000003e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'name': 'backbone.int_blocks.3.*,backbone.out_blocks.4.*', 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Param group 5:\n", + " Params: 71\n", + " Total param size: 26272768\n", + " Other kwargs: {'lr': 2.7500000000000004e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'name': 'backbone.int_blocks.4.*,backbone.out_blocks.5.*', 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Param group 6:\n", + " Params: 71\n", + " Total param size: 26272768\n", + " Other kwargs: {'lr': 3.125e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'name': 'backbone.int_blocks.5.*,backbone.out_blocks.6.*', 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Param group 7:\n", + " Params: 23\n", + " Total param size: 2126080\n", + " Other kwargs: {'lr': 5e-06, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'name': 'rest', 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "CRITICAL:jmp.tasks.finetune.base:Computed warmup_steps: 5\n", + "CRITICAL:jmp.tasks.finetune.base:Computed max_steps: 32\n", + "CRITICAL:jmp.tasks.finetune.base:Computed warmup_steps: 5\n", + "CRITICAL:jmp.tasks.finetune.base:Computed max_steps: 32\n", + "CRITICAL:jmp.tasks.finetune.base:Computed warmup_steps: 5\n", + "CRITICAL:jmp.tasks.finetune.base:Computed max_steps: 32\n", + "CRITICAL:jmp.tasks.finetune.base:Computed warmup_steps: 5\n", + "CRITICAL:jmp.tasks.finetune.base:Computed max_steps: 32\n", + "CRITICAL:jmp.tasks.finetune.base:Computed warmup_steps: 5\n", + "CRITICAL:jmp.tasks.finetune.base:Computed max_steps: 32\n", + "CRITICAL:jmp.tasks.finetune.base:Computed warmup_steps: 5\n", + "CRITICAL:jmp.tasks.finetune.base:Computed max_steps: 32\n", + "CRITICAL:jmp.tasks.finetune.base:Computed warmup_steps: 5\n", + "CRITICAL:jmp.tasks.finetune.base:Computed max_steps: 32\n", + "CRITICAL:jmp.tasks.finetune.base:Computed warmup_steps: 5\n", + "CRITICAL:jmp.tasks.finetune.base:Computed max_steps: 32\n", + "CRITICAL:jmp.tasks.finetune.base:param_group_lr_scheduler_settings=[{'warmup_epochs': 5, 'max_epochs': 32, 'warmup_start_lr': 1.5000000000000002e-07, 'eta_min': 5.000000000000001e-07, 'should_restart': False}, {'warmup_epochs': 5, 'max_epochs': 32, 'warmup_start_lr': 2.7500000000000007e-07, 'eta_min': 5.000000000000001e-07, 'should_restart': False}, {'warmup_epochs': 5, 'max_epochs': 32, 'warmup_start_lr': 2.0000000000000004e-07, 'eta_min': 5.000000000000001e-07, 'should_restart': False}, {'warmup_epochs': 5, 'max_epochs': 32, 'warmup_start_lr': 1.5000000000000002e-07, 'eta_min': 5.000000000000001e-07, 'should_restart': False}, {'warmup_epochs': 5, 'max_epochs': 32, 'warmup_start_lr': 2.0000000000000004e-07, 'eta_min': 5.000000000000001e-07, 'should_restart': False}, {'warmup_epochs': 5, 'max_epochs': 32, 'warmup_start_lr': 2.7500000000000007e-07, 'eta_min': 5.000000000000001e-07, 'should_restart': False}, {'warmup_epochs': 5, 'max_epochs': 32, 'warmup_start_lr': 3.125e-07, 'eta_min': 5.000000000000001e-07, 'should_restart': False}, {'warmup_epochs': 5, 'max_epochs': 32, 'warmup_start_lr': 5.000000000000001e-07, 'eta_min': 5.000000000000001e-07, 'should_restart': False}]\n", + "/opt/conda/envs/jmp/lib/python3.11/site-packages/torch/optim/lr_scheduler.py:28: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n", + " warnings.warn(\"The verbose parameter is deprecated. Please use get_last_lr() \"\n", + "\n", + " | Name | Type | Params\n", + "---------------------------------------------------\n", + "0 | embedding | Embedding | 30.7 K\n", + "1 | backbone | GemNetOCBackbone | 160 M \n", + "2 | out_energy | Sequential | 262 K \n", + "3 | train_metrics | FinetuneMetrics | 0 \n", + "4 | val_metrics | FinetuneMetrics | 0 \n", + "5 | test_metrics | FinetuneMetrics | 0 \n", + "---------------------------------------------------\n", + "160 M Trainable params\n", + "0 Non-trainable params\n", + "160 M Total params\n", + "643.499 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0677ceeb641846debdd9d821699a5611", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? [00:00 None:\n", + " if (ckpt_path := config.meta.get(\"ckpt_path\")) is None:\n", + " raise ValueError(\"No checkpoint path provided\")\n", + "\n", + " model = model_cls(config)\n", + "\n", + " # Load the checkpoint\n", + " state_dict = retreive_state_dict_for_finetuning(\n", + " ckpt_path, load_emas=config.meta.get(\"ema_backbone\", False)\n", + " )\n", + " embedding = filter_state_dict(state_dict, \"embedding.atom_embedding.\")\n", + " backbone = filter_state_dict(state_dict, \"backbone.\")\n", + " model.load_backbone_state_dict(backbone=backbone, embedding=embedding, strict=True)\n", + "\n", + " trainer = Trainer(config)\n", + " trainer.fit(model)\n", + "\n", + "\n", + "runner = Runner(run)\n", + "runner.fast_dev_run(configs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/config/jmp_l_pretrain.ipynb b/config/jmp_l_pretrain.ipynb new file mode 100644 index 0000000..ff15680 --- /dev/null +++ b/config/jmp_l_pretrain.ipynb @@ -0,0 +1,308 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "id='fy0th22x' trainer=TrainerConfig(optimizer=OptimizerConfig(log_grad_norm=True, gradient_clipping=GradientClippingConfig(value=1.0)), supports_skip_batch_exception=False, supports_parameter_hooks=False, set_float32_matmul_precision='medium', precision='16-mixed', use_distributed_sampler=False) optimizer=AdamWConfig(lr=0.0003, weight_decay=0.1, betas=(0.9, 0.95)) lr_scheduler=LinearWarmupCosineAnnealingSchedulerConfig(warmup_steps=2000, max_epochs=2, warmup_start_lr_factor=0.2, min_lr_factor=0.1) edge_dropout=0.1 backbone=BackboneConfig(num_spherical=7, num_radial=128, num_blocks=4, emb_size_atom=256, emb_size_edge=512, emb_size_trip_in=64, emb_size_trip_out=64, emb_size_quad_in=32, emb_size_quad_out=32, emb_size_aint_in=64, emb_size_aint_out=64, emb_size_rbf=16, emb_size_cbf=16, emb_size_sbf=32, num_before_skip=2, num_after_skip=2, num_concat=1, num_atom=3, num_output_afteratom=3, num_atom_emb_layers=2, direct_forces=True, sbf={'name': 'legendre_outer'}, quad_interaction=True, atom_edge_interaction=True, edge_atom_interaction=True, atom_interaction=True, qint_tags=[1, 2], absolute_rbf_cutoff=12.0, dropout=None, edge_dropout=0.1) batch_size=4 num_workers=8 tasks=[TaskConfig(name='oc20', train_dataset=PretrainDatasetConfig(src=PosixPath('/datasets/s2ef/2M/train'), metadata_path=PosixPath('/datasets/s2ef/2M/train_metadata.npz')), val_dataset=PretrainDatasetConfig(src=PosixPath('/datasets/s2ef/all/val_id'), metadata_path=PosixPath('/datasets/s2ef/all/val_id_metadata.npz')), force_loss_scale=73.0, normalization={'y': NormalizationConfig(mean=0.0, std=24.901469505465872), 'force': NormalizationConfig(mean=0.0, std=0.5111534595489502)}), TaskConfig(name='oc22', train_dataset=PretrainDatasetConfig(src=PosixPath('/shared/pre-training-datasets/oc22/s2ef-total/train'), metadata_path=PosixPath('/shared/pre-training-datasets/oc22/s2ef-total/train/metadata.npz')), val_dataset=PretrainDatasetConfig(src=PosixPath('/shared/pre-training-datasets/oc22/s2ef-total/val_id'), metadata_path=PosixPath('/shared/pre-training-datasets/oc22/s2ef-total/val_id/metadata.npz')), force_loss_scale=80.0, normalization={'y': NormalizationConfig(mean=0.0, std=25.229595396538468), 'force': NormalizationConfig(mean=0.0, std=0.25678861141204834)}), TaskConfig(name='ani1x', train_dataset=PretrainDatasetConfig(src=PosixPath('/shared/pre-training-datasets/ani1x/train'), metadata_path=PosixPath('/shared/pre-training-datasets/ani1x/train/metadata.npz')), val_dataset=PretrainDatasetConfig(src=PosixPath('/shared/pre-training-datasets/ani1x/val'), metadata_path=PosixPath('/shared/pre-training-datasets/ani1x/val/metadata.npz')), force_loss_scale=15.0, normalization={'y': NormalizationConfig(mean=0.0, std=2.8700712783472118), 'force': NormalizationConfig(mean=0.0, std=2.131422996520996)}), TaskConfig(name='transition1x', train_dataset=PretrainDatasetConfig(src=PosixPath('/shared/pre-training-datasets/trans1x/train'), metadata_path=PosixPath('/shared/pre-training-datasets/trans1x/train/metadata.npz')), val_dataset=PretrainDatasetConfig(src=PosixPath('/shared/pre-training-datasets/trans1x/val'), metadata_path=PosixPath('/shared/pre-training-datasets/trans1x/val/metadata.npz')), force_loss_scale=14.0, normalization={'y': NormalizationConfig(mean=0.0, std=1.787466168382901), 'force': NormalizationConfig(mean=0.0, std=0.3591422140598297)})] mt_dataset=MTDatasetConfig(sample_type='temperature', sample_temperature=2.0) ema=EMAConfig(decay=0.99)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/workspaces/repositories/fm/src/ll/model/config.py:709: IdSeedWarning: BaseConfig._rng is None. The generated IDs will not be reproducible. To fix this, call BaseConfig.set_seed(...) before generating any IDs.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "\"\"\"\n", + "Copyright (c) Meta Platforms, Inc. and affiliates.\n", + "All rights reserved.\n", + "\n", + "This source code is licensed under the license found in the\n", + "LICENSE file in the root directory of this source tree.\n", + "\"\"\"\n", + "\n", + "from pathlib import Path\n", + "\n", + "from jmp.configs.pretrain.jmp_l import jmp_l_pt_config_\n", + "from jmp.tasks.pretrain import PretrainConfig, PretrainModel\n", + "from jmp.tasks.pretrain.module import (\n", + " NormalizationConfig,\n", + " PretrainDatasetConfig,\n", + " TaskConfig,\n", + ")\n", + "\n", + "\n", + "# Let's make the config\n", + "def jmp_l_config():\n", + " config = PretrainConfig.draft()\n", + "\n", + " jmp_l_pt_config_(config)\n", + "\n", + " # Set data config\n", + " config.batch_size = 4\n", + " config.num_workers = 8\n", + "\n", + " # Set the tasks\n", + " config.tasks = [\n", + " TaskConfig(\n", + " name=\"oc20\",\n", + " train_dataset=PretrainDatasetConfig(\n", + " src=Path(\"/datasets/s2ef/2M/train/\"),\n", + " metadata_path=Path(\"/datasets/s2ef/2M/train_metadata.npz\"),\n", + " ),\n", + " val_dataset=PretrainDatasetConfig(\n", + " src=Path(\"/datasets/s2ef/all/val_id/\"),\n", + " metadata_path=Path(\"/datasets/s2ef/all/val_id_metadata.npz\"),\n", + " ),\n", + " energy_loss_scale=1.0,\n", + " force_loss_scale=73.0,\n", + " normalization={\n", + " \"y\": NormalizationConfig(mean=0.0, std=24.901469505465872),\n", + " \"force\": NormalizationConfig(mean=0.0, std=0.5111534595489502),\n", + " },\n", + " ),\n", + " TaskConfig(\n", + " name=\"oc22\",\n", + " train_dataset=PretrainDatasetConfig(\n", + " src=Path(\"/shared/pre-training-datasets/oc22/s2ef-total/train/\"),\n", + " ),\n", + " val_dataset=PretrainDatasetConfig(\n", + " src=Path(\"/shared/pre-training-datasets/oc22/s2ef-total/val_id/\"),\n", + " ),\n", + " energy_loss_scale=1.0,\n", + " force_loss_scale=80.0,\n", + " normalization={\n", + " \"y\": NormalizationConfig(mean=0.0, std=25.229595396538468),\n", + " \"force\": NormalizationConfig(mean=0.0, std=0.25678861141204834),\n", + " },\n", + " ),\n", + " TaskConfig(\n", + " name=\"ani1x\",\n", + " train_dataset=PretrainDatasetConfig(\n", + " src=Path(\"/shared/pre-training-datasets/ani1x/train/\"),\n", + " ),\n", + " val_dataset=PretrainDatasetConfig(\n", + " src=Path(\"/shared/pre-training-datasets/ani1x/val/\"),\n", + " ),\n", + " energy_loss_scale=1.0,\n", + " force_loss_scale=15.0,\n", + " normalization={\n", + " \"y\": NormalizationConfig(mean=0.0, std=2.8700712783472118),\n", + " \"force\": NormalizationConfig(mean=0.0, std=2.131422996520996),\n", + " },\n", + " ),\n", + " TaskConfig(\n", + " name=\"transition1x\",\n", + " train_dataset=PretrainDatasetConfig(\n", + " src=Path(\"/shared/pre-training-datasets/trans1x/train/\"),\n", + " ),\n", + " val_dataset=PretrainDatasetConfig(\n", + " src=Path(\"/shared/pre-training-datasets/trans1x/val/\"),\n", + " ),\n", + " energy_loss_scale=1.0,\n", + " force_loss_scale=14.0,\n", + " normalization={\n", + " \"y\": NormalizationConfig(mean=0.0, std=1.787466168382901),\n", + " \"force\": NormalizationConfig(mean=0.0, std=0.3591422140598297),\n", + " },\n", + " ),\n", + " ]\n", + "\n", + " return config.finalize()\n", + "\n", + "\n", + "config = jmp_l_config()\n", + "print(config)\n", + "\n", + "configs: list[tuple[PretrainConfig, type[PretrainModel]]] = []\n", + "configs.append((config, PretrainModel))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "abae91062ef54c5db60746cff798ddab", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fast dev run: 0%| | 0/1 [00:00]}.\n", + "Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n", + "Using 16bit Automatic Mixed Precision (AMP)\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/opt/conda/envs/jmp/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "Running in `fast_dev_run` mode: will run the requested loop using 16 batch(es). Logging and checkpointing is suppressed.\n", + "WARNING:ll.trainer.logging:Logger DummyLogger does not support run_id, ignoring.\n", + "CRITICAL:ll.trainer.trainer:LightningTrainer log directory: None.\n", + "/opt/conda/envs/jmp/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:126: `.fit(ckpt_path=None)` was called without a model. The last model of the previous `fit` call will be used. You can pass `fit(ckpt_path='best')` to use the best model or `fit(ckpt_path='last')` to use the last model. If you pass a value, this warning will be silenced.\n", + "/opt/conda/envs/jmp/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py:186: .fit(ckpt_path=\"last\") is set, but there is no last checkpoint available. No checkpoint will be loaded. HINT: Set `ModelCheckpoint(..., save_last=True)`.\n", + "WARNING:ll.model.modules.wandb:Could not find wandb logger or module to log\n", + "CRITICAL:ll.model.base:Fast dev run detected, setting debug flag to True.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "CRITICAL:jmp.tasks.config:Optimizer: AdamW\n", + "Optimizer kwargs: {}\n", + "Base kwargs: {}\n", + "Param groups: Param group 0:\n", + " Params: 405\n", + " Total param size: 44131328\n", + " Other kwargs: {'lr': 0.0003, 'amsgrad': False, 'weight_decay': 0.1, 'betas': (0.9, 0.95), 'eps': 1e-08, 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None}\n", + "Loading `train_dataloader` to estimate number of stepping batches.\n", + "CRITICAL:jmp.tasks.pretrain.module:Setting max_steps=32 by default.\n", + "\n", + " | Name | Type | Params\n", + "---------------------------------------------------\n", + "0 | embedding | Embedding | 30.7 K\n", + "1 | backbone | GemNetOCBackbone | 38.8 M\n", + "2 | output | Output | 5.3 M \n", + "3 | train_metrics | FMMetrics | 0 \n", + "4 | val_metrics | FMMetrics | 0 \n", + "5 | task_steps | TypedModuleDict | 0 \n", + "---------------------------------------------------\n", + "44.1 M Trainable params\n", + "0 Non-trainable params\n", + "44.1 M Total params\n", + "176.525 Total estimated model params size (MB)\n", + "CRITICAL:jmp.modules.dataset.concat_dataset:Ignoring balancing because `ignore_balancing` is True in `MTSampledDataset.__init__`.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bdbcf5bdecbe4d5f8cb39dc296e4c116", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? [00:00 None:\n", + " model = model_cls(config)\n", + " trainer = Trainer(config)\n", + " trainer.fit(model)\n", + "\n", + "\n", + "runner = Runner(run)\n", + "runner.fast_dev_run(configs, n_batches=16)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..0e28cc3 --- /dev/null +++ b/environment.yml @@ -0,0 +1,321 @@ +name: jmp +channels: + - pyg + - pytorch + - nvidia + - conda-forge +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_kmp_llvm + - aiohttp=3.9.3=py311h459d7ec_1 + - aiosignal=1.3.1=pyhd8ed1ab_0 + - alsa-lib=1.2.8=h166bdaf_0 + - annotated-types=0.6.0=pyhd8ed1ab_0 + - appdirs=1.4.4=pyh9f0ad1d_0 + - asttokens=2.4.1=pyhd8ed1ab_0 + - attr=2.5.1=h166bdaf_1 + - attrs=23.2.0=pyh71513ae_0 + - blas=2.116=mkl + - blas-devel=3.9.0=16_linux64_mkl + - brotli=1.1.0=hd590300_1 + - brotli-bin=1.1.0=hd590300_1 + - brotli-python=1.1.0=py311hb755f60_1 + - bzip2=1.0.8=hd590300_5 + - ca-certificates=2024.2.2=hbcca054_0 + - cairo=1.16.0=ha61ee94_1014 + - certifi=2024.2.2=pyhd8ed1ab_0 + - charset-normalizer=3.3.2=pyhd8ed1ab_0 + - click=8.1.7=unix_pyh707e725_0 + - cloudpickle=3.0.0=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - comm=0.2.2=pyhd8ed1ab_0 + - contourpy=1.2.0=py311h9547e67_0 + - coverage=7.4.4=py311h459d7ec_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.4.99=0 + - cuda-runtime=12.1.0=0 + - cycler=0.12.1=pyhd8ed1ab_0 + - dbus=1.13.6=h5008d03_3 + - debugpy=1.8.1=py311hb755f60_0 + - decorator=5.1.1=pyhd8ed1ab_0 + - docker-pycreds=0.4.0=py_0 + - einops=0.7.0=pyhd8ed1ab_1 + - exceptiongroup=1.2.0=pyhd8ed1ab_2 + - executing=2.0.1=pyhd8ed1ab_0 + - expat=2.6.2=h59595ed_0 + - fastcore=1.5.29=pyhd8ed1ab_0 + - ffmpeg=4.3=hf484d3e_0 + - fftw=3.3.10=nompi_hc118613_108 + - filelock=3.13.3=pyhd8ed1ab_0 + - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 + - font-ttf-inconsolata=3.000=h77eed37_0 + - font-ttf-source-code-pro=2.038=h77eed37_0 + - font-ttf-ubuntu=0.83=h77eed37_1 + - fontconfig=2.14.2=h14ed4e7_0 + - fonts-conda-ecosystem=1=0 + - fonts-conda-forge=1=0 + - fonttools=4.50.0=py311h459d7ec_0 + - freetype=2.12.1=h267a509_2 + - frozendict=2.4.1=py311h459d7ec_0 + - frozenlist=1.4.1=py311h459d7ec_0 + - fsspec=2024.3.1=pyhca7485f_0 + - gettext=0.21.1=h27087fc_0 + - gitdb=4.0.11=pyhd8ed1ab_0 + - gitpython=3.1.43=pyhd8ed1ab_0 + - glib=2.80.0=hf2295e7_1 + - glib-tools=2.80.0=hde27a5a_1 + - gmp=6.3.0=h59595ed_1 + - gmpy2=2.1.2=py311h6a5fa03_1 + - gnutls=3.6.13=h85f3911_1 + - graphite2=1.3.13=h59595ed_1003 + - gst-plugins-base=1.22.0=h4243ec0_2 + - gstreamer=1.22.0=h25f0c4b_2 + - gstreamer-orc=0.4.38=hd590300_0 + - harfbuzz=6.0.0=h8e241bc_0 + - icu=70.1=h27087fc_0 + - idna=3.6=pyhd8ed1ab_0 + - importlib-metadata=7.1.0=pyha770c72_0 + - importlib_metadata=7.1.0=hd8ed1ab_0 + - importlib_resources=6.4.0=pyhd8ed1ab_0 + - iniconfig=2.0.0=pyhd8ed1ab_0 + - ipykernel=6.29.3=pyhd33586a_0 + - ipython=8.22.2=pyh707e725_0 + - ipywidgets=8.1.2=pyhd8ed1ab_0 + - jack=1.9.22=h11f4161_0 + - jedi=0.19.1=pyhd8ed1ab_0 + - jinja2=3.1.3=pyhd8ed1ab_0 + - joblib=1.3.2=pyhd8ed1ab_0 + - jpeg=9e=h166bdaf_2 + - jsonschema=4.21.1=pyhd8ed1ab_0 + - jsonschema-specifications=2023.12.1=pyhd8ed1ab_0 + - jupyter_client=8.6.1=pyhd8ed1ab_0 + - jupyter_core=5.7.2=py311h38be061_0 + - jupyterlab_widgets=3.0.10=pyhd8ed1ab_0 + - keyutils=1.6.1=h166bdaf_0 + - kiwisolver=1.4.5=py311h9547e67_1 + - krb5=1.20.1=h81ceb04_0 + - lame=3.100=h166bdaf_1003 + - lcms2=2.15=hfd0df8a_0 + - ld_impl_linux-64=2.40=h41732ed_0 + - lerc=4.0.0=h27087fc_0 + - libabseil=20240116.1=cxx17_h59595ed_2 + - libblas=3.9.0=16_linux64_mkl + - libbrotlicommon=1.1.0=hd590300_1 + - libbrotlidec=1.1.0=hd590300_1 + - libbrotlienc=1.1.0=hd590300_1 + - libcap=2.67=he9d0100_0 + - libcblas=3.9.0=16_linux64_mkl + - libclang=15.0.7=default_h127d8a8_5 + - libclang13=15.0.7=default_h5d6823c_5 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.9.0.20=0 + - libcups=2.3.3=h36d4200_3 + - libcurand=10.3.5.119=0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdb=6.2.32=h9c3ff4c_0 + - libdeflate=1.17=h0b41bf4_0 + - libedit=3.1.20191231=he28a2e2_2 + - libevent=2.1.10=h28343ad_4 + - libexpat=2.6.2=h59595ed_0 + - libffi=3.4.2=h7f98852_5 + - libflac=1.4.3=h59595ed_0 + - libgcc-ng=13.2.0=h807b86a_5 + - libgcrypt=1.10.3=hd590300_0 + - libgfortran-ng=13.2.0=h69a702a_5 + - libgfortran5=13.2.0=ha4646dd_5 + - libglib=2.80.0=hf2295e7_1 + - libgomp=13.2.0=h807b86a_5 + - libgpg-error=1.48=h71f35ed_0 + - libhwloc=2.9.1=hd6dc26d_0 + - libiconv=1.17=hd590300_2 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - liblapack=3.9.0=16_linux64_mkl + - liblapacke=3.9.0=16_linux64_mkl + - libllvm14=14.0.6=hcd5def8_4 + - libllvm15=15.0.7=hadd5161_1 + - libnpp=12.0.2.50=0 + - libnsl=2.0.1=hd590300_0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libogg=1.3.4=h7f98852_1 + - libopus=1.3.1=h7f98852_1 + - libpng=1.6.43=h2797004_0 + - libpq=15.3=hbcd7760_1 + - libprotobuf=4.25.3=h08a7969_0 + - libsndfile=1.2.2=hc60ed4a_1 + - libsodium=1.0.18=h36c2ea0_1 + - libsqlite=3.45.2=h2797004_0 + - libstdcxx-ng=13.2.0=h7e041cc_5 + - libsystemd0=253=h8c4010b_1 + - libtiff=4.5.0=h6adf6a1_2 + - libtool=2.4.7=h27087fc_0 + - libudev1=253=h0b41bf4_1 + - libuuid=2.38.1=h0b41bf4_0 + - libvorbis=1.3.7=h9c3ff4c_0 + - libwebp-base=1.3.2=hd590300_0 + - libxcb=1.13=h7f98852_1004 + - libxcrypt=4.4.36=hd590300_1 + - libxkbcommon=1.5.0=h79f4944_1 + - libxml2=2.10.3=hca2bb57_4 + - libzlib=1.2.13=hd590300_5 + - lightning=2.2.1=pyhd8ed1ab_0 + - lightning-utilities=0.11.2=pyhd8ed1ab_0 + - llvm-openmp=15.0.7=h0cdce71_0 + - llvmlite=0.42.0=py311ha6695c7_1 + - lovely-numpy=0.2.11=pyhd8ed1ab_0 + - lovely-tensors=0.1.15=pyhd8ed1ab_0 + - lz4-c=1.9.4=hcb278e6_0 + - markupsafe=2.1.5=py311h459d7ec_0 + - matplotlib=3.8.3=py311h38be061_0 + - matplotlib-base=3.8.3=py311h54ef318_0 + - matplotlib-inline=0.1.6=pyhd8ed1ab_0 + - mkl=2022.1.0=h84fe81f_915 + - mkl-devel=2022.1.0=ha770c72_916 + - mkl-include=2022.1.0=h84fe81f_915 + - mpc=1.3.1=hfe3b2da_0 + - mpfr=4.2.1=h9458935_0 + - mpg123=1.32.4=h59595ed_0 + - mpmath=1.3.0=pyhd8ed1ab_0 + - multidict=6.0.5=py311h459d7ec_0 + - munkres=1.1.4=pyh9f0ad1d_0 + - mysql-common=8.0.33=hf1915f5_6 + - mysql-libs=8.0.33=hca2cd23_6 + - nbformat=5.10.3=pyhd8ed1ab_0 + - nbval=0.11.0=pyhd8ed1ab_0 + - ncurses=6.4.20240210=h59595ed_0 + - nest-asyncio=1.6.0=pyhd8ed1ab_0 + - nettle=3.6=he412f7d_0 + - networkx=3.2.1=pyhd8ed1ab_0 + - nspr=4.35=h27087fc_0 + - nss=3.98=h1d7d5a4_0 + - numba=0.59.1=py311h96b013e_0 + - numpy=1.26.4=py311h64a7726_0 + - openh264=2.1.1=h780b84a_0 + - openjpeg=2.5.0=hfec8fc6_2 + - openssl=3.1.5=hd590300_0 + - packaging=24.0=pyhd8ed1ab_0 + - pandas=2.2.1=py311h320fe9a_0 + - parso=0.8.3=pyhd8ed1ab_0 + - pathtools=0.1.2=py_1 + - patsy=0.5.6=pyhd8ed1ab_0 + - pcre2=10.43=hcad00b1_0 + - pexpect=4.9.0=pyhd8ed1ab_0 + - pickleshare=0.7.5=py_1003 + - pillow=9.4.0=py311h50def17_1 + - pip=24.0=pyhd8ed1ab_0 + - pixman=0.43.2=h59595ed_0 + - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1 + - platformdirs=4.2.0=pyhd8ed1ab_0 + - plotly=5.19.0=pyhd8ed1ab_0 + - pluggy=1.4.0=pyhd8ed1ab_0 + - ply=3.11=py_1 + - prompt-toolkit=3.0.42=pyha770c72_0 + - protobuf=4.25.3=py311h7b78aeb_0 + - psutil=5.9.8=py311h459d7ec_0 + - pthread-stubs=0.4=h36c2ea0_1001 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pulseaudio=16.1=hcb278e6_3 + - pulseaudio-client=16.1=h5195f5e_3 + - pulseaudio-daemon=16.1=ha8d29e2_3 + - pure_eval=0.2.2=pyhd8ed1ab_0 + - pydantic=2.6.4=pyhd8ed1ab_0 + - pydantic-core=2.16.3=py311h46250e7_0 + - pyg=2.5.2=py311_torch_2.2.0_cu121 + - pygments=2.17.2=pyhd8ed1ab_0 + - pyparsing=3.1.2=pyhd8ed1ab_0 + - pyqt=5.15.9=py311hf0fb5b6_5 + - pyqt5-sip=12.12.2=py311hb755f60_5 + - pysocks=1.7.1=pyha2e5f31_6 + - pytest=8.1.1=pyhd8ed1ab_0 + - python=3.11.6=hab00c5b_0_cpython + - python-dateutil=2.9.0=pyhd8ed1ab_0 + - python-fastjsonschema=2.19.1=pyhd8ed1ab_0 + - python-tzdata=2024.1=pyhd8ed1ab_0 + - python_abi=3.11=4_cp311 + - pytorch=2.2.2=py3.11_cuda12.1_cudnn8.9.2_0 + - pytorch-cluster=1.6.3=py311_torch_2.2.0_cu121 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-lightning=2.2.1=pyhd8ed1ab_0 + - pytorch-mutex=1.0=cuda + - pytorch-scatter=2.1.2=py311_torch_2.2.0_cu121 + - pytorch-sparse=0.6.18=py311_torch_2.2.0_cu121 + - pytz=2024.1=pyhd8ed1ab_0 + - pyyaml=6.0.1=py311h459d7ec_1 + - pyzmq=25.1.2=py311h34ded2d_0 + - qt-main=5.15.8=h5d23da1_6 + - readline=8.2=h8228510_1 + - referencing=0.34.0=pyhd8ed1ab_0 + - requests=2.31.0=pyhd8ed1ab_0 + - rpds-py=0.18.0=py311h46250e7_0 + - scikit-learn=1.4.1.post1=py311hc009520_0 + - scipy=1.12.0=py311h64a7726_2 + - seaborn=0.13.2=hd8ed1ab_0 + - seaborn-base=0.13.2=pyhd8ed1ab_0 + - sentry-sdk=1.44.1=pyhd8ed1ab_0 + - setproctitle=1.3.3=py311h459d7ec_0 + - setuptools=69.2.0=pyhd8ed1ab_0 + - sip=6.7.12=py311hb755f60_0 + - six=1.16.0=pyh6c4a22f_0 + - smmap=5.0.0=pyhd8ed1ab_0 + - stack_data=0.6.2=pyhd8ed1ab_0 + - statsmodels=0.14.1=py311h1f0f07a_0 + - sympy=1.12=pypyh9d50eac_103 + - tbb=2021.9.0=hf52228f_0 + - tenacity=8.2.3=pyhd8ed1ab_0 + - threadpoolctl=3.4.0=pyhc1e730c_0 + - tk=8.6.13=noxft_h4845f30_101 + - toml=0.10.2=pyhd8ed1ab_0 + - tomli=2.0.1=pyhd8ed1ab_0 + - torchaudio=2.2.2=py311_cu121 + - torchmetrics=1.3.2=pyhd8ed1ab_0 + - torchtriton=2.2.0=py311 + - torchvision=0.17.2=py311_cu121 + - tornado=6.4=py311h459d7ec_0 + - tqdm=4.66.2=pyhd8ed1ab_0 + - traitlets=5.14.2=pyhd8ed1ab_0 + - typing-extensions=4.10.0=hd8ed1ab_0 + - typing_extensions=4.10.0=pyha770c72_0 + - tzdata=2024a=h0c530f3_0 + - urllib3=2.2.1=pyhd8ed1ab_0 + - varname=0.13.0=pyhd8ed1ab_0 + - wandb=0.16.5=pyhd8ed1ab_0 + - wcwidth=0.2.13=pyhd8ed1ab_0 + - wheel=0.43.0=pyhd8ed1ab_1 + - widgetsnbextension=4.0.10=pyhd8ed1ab_0 + - wrapt=1.16.0=py311h459d7ec_0 + - xcb-util=0.4.0=h516909a_0 + - xcb-util-image=0.4.0=h166bdaf_0 + - xcb-util-keysyms=0.4.0=h516909a_0 + - xcb-util-renderutil=0.3.9=h166bdaf_0 + - xcb-util-wm=0.4.1=h516909a_0 + - xkeyboard-config=2.38=h0b41bf4_0 + - xorg-kbproto=1.0.7=h7f98852_1002 + - xorg-libice=1.1.1=hd590300_0 + - xorg-libsm=1.2.4=h7391055_0 + - xorg-libx11=1.8.4=h0b41bf4_0 + - xorg-libxau=1.0.11=hd590300_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xorg-libxext=1.3.4=h0b41bf4_2 + - xorg-libxrender=0.9.10=h7f98852_1003 + - xorg-renderproto=0.11.1=h7f98852_1002 + - xorg-xextproto=7.3.0=h0b41bf4_1003 + - xorg-xproto=7.0.31=h7f98852_1007 + - xz=5.2.6=h166bdaf_0 + - yaml=0.2.5=h7f98852_2 + - yarl=1.9.4=py311h459d7ec_0 + - zeromq=4.3.5=h59595ed_1 + - zipp=3.17.0=pyhd8ed1ab_0 + - zlib=1.2.13=hd590300_5 + - zstd=1.5.5=hfc55251_0 + - pip: + - ase==3.22.1 + - biopython==1.83 + - deepchem==2.8.0 + - lmdb==1.4.1 + - rdkit==2023.9.5 diff --git a/images/README/large_molecules.png b/images/README/large_molecules.png new file mode 100644 index 0000000..01341ec Binary files /dev/null and b/images/README/large_molecules.png differ diff --git a/images/README/main_figure.png b/images/README/main_figure.png new file mode 100644 index 0000000..0faaae5 Binary files /dev/null and b/images/README/main_figure.png differ diff --git a/images/README/materials.png b/images/README/materials.png new file mode 100644 index 0000000..379e827 Binary files /dev/null and b/images/README/materials.png differ diff --git a/images/README/qm9.png b/images/README/qm9.png new file mode 100644 index 0000000..8e589b0 Binary files /dev/null and b/images/README/qm9.png differ diff --git a/images/README/rmd17.png b/images/README/rmd17.png new file mode 100644 index 0000000..f56c4b2 Binary files /dev/null and b/images/README/rmd17.png differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..06ca54e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,9 @@ +[project] +name = "jmp" +version = "0.0.1" + +[tool.pyright] +deprecateTypingAliases = true +strictListInference = true +strictDictionaryInference = true +strictSetInference = true diff --git a/scripts/generate_metadata.py b/scripts/generate_metadata.py new file mode 100644 index 0000000..c14a38d --- /dev/null +++ b/scripts/generate_metadata.py @@ -0,0 +1,111 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +from pathlib import Path +from typing import assert_never + +import numpy as np +from jmp.datasets.finetune.base import LmdbDataset as FinetuneLmdbDataset +from jmp.datasets.pretrain_lmdb import PretrainDatasetConfig, PretrainLmdbDataset +from torch.utils.data import DataLoader, Dataset +from torch_geometric.data.data import BaseData +from tqdm import tqdm + + +def _gather_metadata(dataset: Dataset[BaseData], num_workers: int, batch_size: int): + loader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=lambda data_list: np.array( + [data.pos.shape[0] for data in data_list] + ), + shuffle=False, + ) + + natoms_list: list[np.ndarray] = [] + for natoms in tqdm(loader, total=len(loader)): + natoms_list.append(natoms) + + natoms = np.concatenate(natoms_list) + return natoms + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--src", + type=Path, + help="Path to the LMDB file or directory containing LMDB files.", + required=True, + ) + parser.add_argument( + "--dest", + type=Path, + help="Where to save the metadata npz file.", + required=False, + ) + parser.add_argument( + "--type", + type=str, + choices=["pretrain", "finetune"], + help="Type of dataset to gather metadata from.", + required=True, + ) + parser.add_argument( + "--num-workers", + type=int, + help="Number of workers to use for data loading.", + default=32, + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch size to use for data loading.", + default=256, + ) + args = parser.parse_args() + + # Parse and validate arguments + src: Path = args.src + dest: Path | None = args.dest + dataset_type: str = args.type + num_workers: int = args.num_workers + batch_size: int = args.batch_size + + if dest is None: + dest = src / "metadata.npz" + + assert src.exists(), f"{src} does not exist" + assert src.is_file() or src.is_dir(), f"{src} is not a file or directory" + + assert dest.suffix == ".npz", f"{dest} is not a .npz file" + assert not dest.exists(), f"{dest} already exists" + + assert dataset_type in ("pretrain", "finetune"), f"{dataset_type} is not valid" + + # Load dataset + match dataset_type: + case "pretrain": + dataset = PretrainLmdbDataset(PretrainDatasetConfig(src=src)) + case "finetune": + dataset = FinetuneLmdbDataset(src=src) + case _: + assert_never(dataset_type) + + # Gather metadata + natoms = _gather_metadata(dataset, num_workers, batch_size) + assert natoms.shape[0] == len(dataset), f"{natoms.shape[0]=} != {len(dataset)=}" + + # Save metadata + np.savez(dest, natoms=natoms) + + +if __name__ == "__main__": + main() diff --git a/src/jmp/__init__.py b/src/jmp/__init__.py new file mode 100644 index 0000000..7e1665b --- /dev/null +++ b/src/jmp/__init__.py @@ -0,0 +1,7 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" diff --git a/src/jmp/configs/finetune/jmp_l.py b/src/jmp/configs/finetune/jmp_l.py new file mode 100644 index 0000000..1424873 --- /dev/null +++ b/src/jmp/configs/finetune/jmp_l.py @@ -0,0 +1,99 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path + +from jmp.lightning import GradientClippingConfig + +from ...models.gemnet.config import BackboneConfig +from ...tasks.config import AdamWConfig +from ...tasks.finetune import FinetuneConfigBase +from ...tasks.finetune.base import ( + CheckpointBestConfig, + EarlyStoppingConfig, + RLPConfig, + WarmupCosRLPConfig, +) +from ...utils.param_specific_util import make_parameter_specific_optimizer_config + + +def jmp_l_ft_config_( + config: FinetuneConfigBase, + ckpt_path: Path, + ema_backbone: bool = True, + disable_force_output_heads: bool = True, +): + # Set the model trainer settings for maximum performance + config.trainer.precision = "16-mixed" + config.trainer.set_float32_matmul_precision = "medium" + config.trainer.supports_parameter_hooks = False + config.trainer.supports_skip_batch_exception = False + + # Set backbone config + config.backbone = BackboneConfig.large() + config.embedding.embedding_size = config.backbone.emb_size_atom + config.backbone.scale_basis = False + + # Optimizer settings + config.optimizer = AdamWConfig( + lr=5.0e-6, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + config.trainer.optimizer.log_grad_norm = True + config.trainer.optimizer.gradient_clipping = GradientClippingConfig( + value=1.0, + algorithm="value", + ) + # LR Scheduler settings + config.lr_scheduler = WarmupCosRLPConfig( + warmup_epochs=5, + warmup_start_lr_factor=1.0e-1, + should_restart=False, + max_epochs=32, + min_lr_factor=0.1, + rlp=RLPConfig(patience=3, factor=0.8), + ) + # LLRD Settings + config.parameter_specific_optimizers = make_parameter_specific_optimizer_config( + config, + config.backbone.num_blocks, + { + "embedding": 0.3, + "blocks_0": 0.55, + "blocks_1": 0.40, + "blocks_2": 0.30, + "blocks_3": 0.40, + "blocks_4": 0.55, + "blocks_5": 0.625, + }, + ) + + # Checkpoint loading settings + # We want to use EMA weights from pretraining + config.meta["ckpt_path"] = ckpt_path + config.meta["ema_backbone"] = ema_backbone + + # Set data config + config.num_workers = 8 + + # Base early stopping settings + config.trainer.max_epochs = 500 + config.trainer.max_time = "07:00:00:00" + config.early_stopping = EarlyStoppingConfig( + patience=50, + min_delta=1.0e-8, + min_lr=1.0e-8, + ) + config.ckpt_best = CheckpointBestConfig() + + # If we are not using force output heads, we need to disable them + if disable_force_output_heads: + config.backbone.regress_forces = False + config.backbone.direct_forces = False diff --git a/src/jmp/configs/finetune/matbench.py b/src/jmp/configs/finetune/matbench.py new file mode 100644 index 0000000..e573047 --- /dev/null +++ b/src/jmp/configs/finetune/matbench.py @@ -0,0 +1,87 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path + +from ...modules.transforms.normalize import NormalizationConfig as NC +from ...tasks.config import AdamWConfig +from ...tasks.finetune import MatbenchConfig +from ...tasks.finetune import dataset_config as DC +from ...tasks.finetune.base import PrimaryMetricConfig + +STATS: dict[str, dict[str, NC]] = { + "jdft2d_fold0": {"y": NC(mean=110.63706001904778, std=132.02502987887982)}, + "jdft2d_fold1": {"y": NC(mean=100.05996525195053, std=114.26362221432791)}, + "jdft2d_fold2": {"y": NC(mean=101.59535193788061, std=112.45760038504558)}, + "jdft2d_fold3": {"y": NC(mean=99.43551549230911, std=109.9220303290942)}, + "jdft2d_fold4": {"y": NC(mean=95.50851385805468, std=76.27587565670332)}, + "phonons_fold0": {"y": NC(mean=602.9007780432183, std=471.03858838413055)}, + "phonons_fold1": {"y": NC(mean=613.7996473907606, std=486.75099875453213)}, + "phonons_fold2": {"y": NC(mean=619.3868976573087, std=495.2975486965762)}, + "phonons_fold3": {"y": NC(mean=609.7402387661577, std=462.3438660855412)}, + "phonons_fold4": {"y": NC(mean=595.4547502676089, std=476.8567310885976)}, + "dielectric_fold0": {"y": NC(mean=2.417849270334958, std=2.208662738016193)}, + "dielectric_fold1": {"y": NC(mean=2.3716402963883074, std=2.1271523121706912)}, + "dielectric_fold2": {"y": NC(mean=2.354418196731436, std=1.5712251872961516)}, + "dielectric_fold3": {"y": NC(mean=2.392308273978868, std=2.0724149898647544)}, + "dielectric_fold4": {"y": NC(mean=2.3891527750974495, std=2.011348533899877)}, + "log_gvrh_fold0": {"y": NC(mean=1.5557434474198688, std=0.37307197984408746)}, + "log_gvrh_fold1": {"y": NC(mean=1.5584101768747889, std=0.36743473539736493)}, + "log_gvrh_fold2": {"y": NC(mean=1.55746252819908, std=0.36800038945046654)}, + "log_gvrh_fold3": {"y": NC(mean=1.5543022349873286, std=0.3684552493569905)}, + "log_gvrh_fold4": {"y": NC(mean=1.5595705795473838, std=0.37039750391284176)}, + "log_kvrh_fold0": {"y": NC(mean=1.880001033036957, std=0.36820395518377785)}, + "log_kvrh_fold1": {"y": NC(mean=1.883820392919235, std=0.3679308395031994)}, + "log_kvrh_fold2": {"y": NC(mean=1.883778380784775, std=0.3724392829717956)}, + "log_kvrh_fold3": {"y": NC(mean=1.8828457515367547, std=0.3731179944882516)}, + "log_kvrh_fold4": {"y": NC(mean=1.8862681006404232, std=0.3671596024523317)}, + "perovskites_fold0": {"y": NC(mean=1.4726657310327749, std=0.7384309800882398)}, + "perovskites_fold1": {"y": NC(mean=1.4690728968876414, std=0.736635027626099)}, + "perovskites_fold2": {"y": NC(mean=1.4702980269132337, std=0.7456716470700677)}, + "perovskites_fold3": {"y": NC(mean=1.46773815420175, std=0.7431740904365189)}, + "perovskites_fold4": {"y": NC(mean=1.478002311375268, std=0.7435117654840315)}, + "mp_gap_fold0": {"y": NC(mean=1.236432240252091, std=1.6096424425437108)}, + "mp_gap_fold1": {"y": NC(mean=1.2345678083402052, std=1.6044708412420103)}, + "mp_gap_fold2": {"y": NC(mean=1.2352391374131229, std=1.6058092380465256)}, + "mp_gap_fold3": {"y": NC(mean=1.230066812934386, std=1.6003749533498033)}, + "mp_gap_fold4": {"y": NC(mean=1.2350543114618917, std=1.6035590734723943)}, + "mp_e_form_fold0": {"y": NC(mean=-1.4371594843879998, std=1.1577096884761835)}, + "mp_e_form_fold1": {"y": NC(mean=-1.4372781184639032, std=1.1576872656463288)}, + "mp_e_form_fold2": {"y": NC(mean=-1.4353308741245294, std=1.1568986659292604)}, + "mp_e_form_fold3": {"y": NC(mean=-1.4337824626302396, std=1.1570679204976484)}, + "mp_e_form_fold4": {"y": NC(mean=-1.437067044514929, std=1.1567267481888575)}, +} + + +def jmp_l_matbench_config_( + config: MatbenchConfig, + dataset: DC.MatbenchDataset, + fold: DC.MatbenchFold, + base_path: Path, +): + # Optimizer settings + config.optimizer = AdamWConfig( + lr=5.0e-6, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + + # Set up dataset + config.train_dataset = DC.matbench_config(dataset, base_path, "train", fold) + config.val_dataset = DC.matbench_config(dataset, base_path, "val", fold) + config.test_dataset = DC.matbench_config(dataset, base_path, "test", fold) + + # Set up normalization + if (normalization_config := STATS.get(f"{dataset}_{fold}")) is None: + raise ValueError(f"Normalization for {dataset}_{fold} not found") + config.normalization = normalization_config + + # MatBench specific settings + config.dataset = dataset + config.primary_metric = PrimaryMetricConfig(name="y_mae", mode="min") diff --git a/src/jmp/configs/finetune/md22.py b/src/jmp/configs/finetune/md22.py new file mode 100644 index 0000000..6ff4ab7 --- /dev/null +++ b/src/jmp/configs/finetune/md22.py @@ -0,0 +1,82 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path + +from ...modules.transforms.normalize import NormalizationConfig as NC +from ...tasks.config import AdamWConfig +from ...tasks.finetune import MD22Config +from ...tasks.finetune import dataset_config as DC +from ...tasks.finetune.base import PrimaryMetricConfig + +STATS: dict[str, dict[str, NC]] = { + "Ac-Ala3-NHMe": { + "y": NC(mean=-26913.953, std=0.35547638), + "force": NC(mean=0.0, std=1.1291506), + }, + "DHA": { + "y": NC(mean=-27383.035, std=0.41342595), + "force": NC(mean=0.0, std=1.1258113), + }, + "stachyose": { + "y": NC(mean=-68463.59, std=0.5940788), + "force": NC(mean=0.0, std=1.1104717), + }, + "AT-AT": { + "y": NC(mean=-50080.08, std=0.47309175), + "force": NC(mean=0.0, std=1.2109985), + }, + "AT-AT-CG-CG": { + "y": NC(mean=-101034.23, std=0.680055), + "force": NC(mean=0.0, std=1.2021886), + }, + "buckyball-catcher": { + "y": NC(mean=-124776.7, std=0.64662045), + "force": NC(mean=0.0, std=1.0899031), + }, + "double-walled_nanotube": { + "y": NC(mean=-338224.16, std=3.3810701), + "force": NC(mean=0.0, std=1.0137014), + }, +} + + +def jmp_l_md22_config_( + config: MD22Config, + molecule: DC.MD22Molecule, + base_path: Path, +): + # Optimizer settings + config.optimizer = AdamWConfig( + lr=5.0e-6, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + + # Set data config + config.batch_size = 4 + + # Set up dataset + config.train_dataset = DC.md22_config(molecule, base_path, "train") + config.val_dataset = DC.md22_config(molecule, base_path, "val") + config.test_dataset = DC.md22_config(molecule, base_path, "test") + + # MD22 specific settings + config.molecule = molecule + config.primary_metric = PrimaryMetricConfig(name="force_mae", mode="min") + + # Gradient forces + config.model_type = "forces" + config.gradient_forces = True + config.trainer.inference_mode = False + + # Set up normalization + if (normalization_config := STATS.get(molecule)) is None: + raise ValueError(f"Normalization for {molecule} not found") + config.normalization = normalization_config diff --git a/src/jmp/configs/finetune/pdbbind.py b/src/jmp/configs/finetune/pdbbind.py new file mode 100644 index 0000000..8a4eb4e --- /dev/null +++ b/src/jmp/configs/finetune/pdbbind.py @@ -0,0 +1,40 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from ...tasks.config import AdamWConfig +from ...tasks.finetune import PDBBindConfig +from ...tasks.finetune import dataset_config as DC +from ...tasks.finetune.base import PrimaryMetricConfig + + +def jmp_l_pdbbind_config_(config: PDBBindConfig, target: str = "y"): + # Optimizer settings + config.optimizer = AdamWConfig( + lr=5.0e-6, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + + # Set up dataset + config.train_dataset = DC.pdbbind_config("train") + config.val_dataset = DC.pdbbind_config("val") + config.test_dataset = DC.pdbbind_config("test") + + # PDBBind specific settings + config.primary_metric = PrimaryMetricConfig(name="y_mae", mode="min") + + # Make sure we only optimize for the single target + config.graph_scalar_targets = [target] + config.node_vector_targets = [] + config.graph_classification_targets = [] + config.graph_scalar_reduction = {target: "sum"} + + # PDBBind specific settings + config.pbdbind_task = "-logKd/Ki" + config.metrics.report_rmse = True diff --git a/src/jmp/configs/finetune/qm9.py b/src/jmp/configs/finetune/qm9.py new file mode 100644 index 0000000..185b4fe --- /dev/null +++ b/src/jmp/configs/finetune/qm9.py @@ -0,0 +1,71 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path + +from ...modules.transforms.normalize import NormalizationConfig as NC +from ...tasks.config import AdamWConfig +from ...tasks.finetune import QM9Config +from ...tasks.finetune import dataset_config as DC +from ...tasks.finetune.base import PrimaryMetricConfig +from ...tasks.finetune.qm9 import QM9Target, SpatialExtentConfig + +STATS: dict[str, NC] = { + "mu": NC(mean=2.674587, std=1.5054824), + "alpha": NC(mean=75.31013, std=8.164021), + "eps_HMO": NC(mean=-6.5347567, std=0.59702325), + "eps_LUMO": NC(mean=0.323833, std=1.273586), + "delta_eps": NC(mean=6.8585854, std=1.283122), + "R_2_Abs": NC(mean=1189.6819, std=280.0421), + "ZPVE": NC(mean=-0.00052343315, std=0.04904531), + "U_0": NC(mean=0.0028667436, std=1.0965848), + "U": NC(mean=0.0028711546, std=1.0941933), + "H": NC(mean=0.0029801112, std=1.0942822), + "G": NC(mean=0.000976671, std=1.101572), + "c_v": NC(mean=-0.005799451, std=2.2179737), + "U_0_ATOM": NC(mean=-76.15232, std=10.309152), + "U_ATOM": NC(mean=-76.6171, std=10.400515), + "H_ATOM": NC(mean=-77.05511, std=10.474532), + "G_ATOM": NC(mean=-70.87026, std=9.484609), + "A": NC(mean=11.58375, std=2046.5049), + "B": NC(mean=1.40327, std=1.1445134), + "C": NC(mean=1.1256535, std=0.85679144), +} + + +def jmp_l_qm9_config_(config: QM9Config, target: QM9Target, base_path: Path): + # Optimizer settings + config.optimizer = AdamWConfig( + lr=5.0e-6, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + + # Set up dataset + config.train_dataset = DC.qm9_config(base_path, "train") + config.val_dataset = DC.qm9_config(base_path, "val") + config.test_dataset = DC.qm9_config(base_path, "test") + + # Set up normalization + if (normalization_config := STATS.get(target)) is None: + raise ValueError(f"Normalization for {target} not found") + config.normalization = {target: normalization_config} + + # QM9 specific settings + config.primary_metric = PrimaryMetricConfig(name="y_mae", mode="min") + + # Make sure we only optimize for the target + config.graph_scalar_targets = [target] + config.node_vector_targets = [] + config.graph_classification_targets = [] + config.graph_scalar_reduction = {target: "sum"} + + # Handle R_2_Abs separately + if target == "R_2_Abs": + config.output_head = SpatialExtentConfig() diff --git a/src/jmp/configs/finetune/qmof.py b/src/jmp/configs/finetune/qmof.py new file mode 100644 index 0000000..9ea0a3f --- /dev/null +++ b/src/jmp/configs/finetune/qmof.py @@ -0,0 +1,48 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path + +from ...modules.transforms.normalize import NormalizationConfig as NC +from ...tasks.config import AdamWConfig +from ...tasks.finetune import QMOFConfig +from ...tasks.finetune import dataset_config as DC +from ...tasks.finetune.base import PrimaryMetricConfig + +STATS: dict[str, NC] = { + "y": NC(mean=2.1866251527, std=1.175752521125648), +} + + +def jmp_l_qmof_config_(config: QMOFConfig, base_path: Path, target: str = "y"): + # Optimizer settings + config.optimizer = AdamWConfig( + lr=5.0e-6, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + + # Set up dataset + config.train_dataset = DC.qmof_config(base_path, "train") + config.val_dataset = DC.qmof_config(base_path, "val") + config.test_dataset = DC.qmof_config(base_path, "test") + + # Set up normalization + if (normalization_config := STATS.get(target)) is None: + raise ValueError(f"Normalization for {target} not found") + config.normalization = {target: normalization_config} + + # QMOF specific settings + config.primary_metric = PrimaryMetricConfig(name="y_mae", mode="min") + + # Make sure we only optimize for the single target + config.graph_scalar_targets = [target] + config.node_vector_targets = [] + config.graph_classification_targets = [] + config.graph_scalar_reduction = {target: "sum"} diff --git a/src/jmp/configs/finetune/rmd17.py b/src/jmp/configs/finetune/rmd17.py new file mode 100644 index 0000000..d12b5bc --- /dev/null +++ b/src/jmp/configs/finetune/rmd17.py @@ -0,0 +1,119 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path + +from ...modules.transforms.normalize import NormalizationConfig as NC +from ...tasks.config import AdamWConfig +from ...tasks.finetune import RMD17Config +from ...tasks.finetune import dataset_config as DC +from ...tasks.finetune.base import ( + EarlyStoppingConfig, + PrimaryMetricConfig, + RLPConfig, + WarmupCosRLPConfig, +) + +STATS: dict[str, dict[str, NC]] = { + "aspirin": { + "y": NC(mean=-17617.379355234374, std=0.2673998440577667), + "force": NC(mean=0.0, std=1.2733363), + }, + "azobenzene": { + "y": NC(mean=-15553.118351233397, std=0.2866098335926971), + "force": NC(mean=0.0, std=1.2940075), + }, + "benzene": { + "y": NC(mean=-6306.374855859375, std=0.10482645661015047), + "force": NC(mean=0.0, std=0.90774584), + }, + "ethanol": { + "y": NC(mean=-4209.534573266602, std=0.18616576961275716), + "force": NC(mean=0.0, std=1.1929188), + }, + "malonaldehyde": { + "y": NC(mean=-7254.903633896484, std=0.1812291921138577), + "force": NC(mean=0.0, std=1.302443), + }, + "naphthalene": { + "y": NC(mean=-10478.192319667969, std=0.24922674853668708), + "force": NC(mean=0.0, std=1.3102233), + }, + "paracetamol": { + "y": NC(mean=-13998.780924130859, std=0.26963984094801224), + "force": NC(mean=0.0, std=1.2707518), + }, + "salicylic": { + "y": NC(mean=-13472.110348867187, std=0.2437920552529055), + "force": NC(mean=0.0, std=1.3030343), + }, + "toluene": { + "y": NC(mean=-7373.347077485351, std=0.22534282741069667), + "force": NC(mean=0.0, std=1.246547), + }, + "uracil": { + "y": NC(mean=-11266.351949697266, std=0.2227113171300836), + "force": NC(mean=0.0, std=1.3692871), + }, +} + + +def jmp_l_rmd17_config_( + config: RMD17Config, molecule: DC.RMD17Molecule, base_path: Path +): + # Optimizer settings + config.optimizer = AdamWConfig( + lr=5.0e-6, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + + # Set data config + config.batch_size = 4 + + # Set up dataset + config.train_dataset = DC.rmd17_config(molecule, base_path, "train") + config.val_dataset = DC.rmd17_config(molecule, base_path, "val") + config.test_dataset = DC.rmd17_config(molecule, base_path, "test") + + # RMD17 specific settings + config.molecule = molecule + config.primary_metric = PrimaryMetricConfig(name="force_mae", mode="min") + + # Gradient forces + config.model_type = "forces" + config.gradient_forces = True + config.trainer.inference_mode = False + + # Set up normalization + if (normalization_config := STATS.get(molecule)) is None: + raise ValueError(f"Normalization for {molecule} not found") + config.normalization = normalization_config + + # We use more conservative early stopping for rMD17 + # (we essentially copy Allegro here). + config.trainer.max_epochs = 100_000 + config.trainer.max_time = "07:00:00:00" + config.early_stopping = EarlyStoppingConfig( + patience=1000, + min_delta=1.0e-8, + min_lr=1.0e-10, + ) + + # We also use a conservative set of hyperparameters + # for ReduceLROnPlateau (again, we copy Allegro here). + # The main difference is that we use a larger patience (25 vs 3). + config.lr_scheduler = WarmupCosRLPConfig( + warmup_epochs=5, + warmup_start_lr_factor=1.0e-1, + should_restart=False, + max_epochs=32, + min_lr_factor=0.1, + rlp=RLPConfig(patience=25, factor=0.8), + ) diff --git a/src/jmp/configs/finetune/spice.py b/src/jmp/configs/finetune/spice.py new file mode 100644 index 0000000..8e7f82b --- /dev/null +++ b/src/jmp/configs/finetune/spice.py @@ -0,0 +1,58 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path + +from ...modules.transforms.normalize import NormalizationConfig as NC +from ...tasks.config import AdamWConfig +from ...tasks.finetune import SPICEConfig +from ...tasks.finetune import dataset_config as DC +from ...tasks.finetune.base import PrimaryMetricConfig + +STATS: dict[str, dict[str, NC]] = { + "dipeptides": { + "y": NC(mean=-31213.615, std=4636.815), + "force": NC(mean=3.3810358e-07, std=0.5386545), + }, + "solvated_amino_acids": { + "y": NC(mean=-60673.68, std=3310.6692), + "force": NC(mean=2.7950014e-07, std=0.81945145), + }, +} + + +def jmp_l_spice_config_(config: SPICEConfig, dataset: DC.SPICEDataset, base_path: Path): + # Optimizer settings + config.optimizer = AdamWConfig( + lr=5.0e-6, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + + # Set data config + config.batch_size = 1 + + # Set up dataset + config.train_dataset = DC.spice_config(dataset, base_path, "train") + config.val_dataset = DC.spice_config(dataset, base_path, "val") + config.test_dataset = DC.spice_config(dataset, base_path, "test") + + # Spice specific settings + config.dataset = dataset + config.primary_metric = PrimaryMetricConfig(name="force_mae", mode="min") + + # Gradient forces + config.model_type = "forces" + config.gradient_forces = True + config.trainer.inference_mode = False + + # Set up normalization + if (normalization_config := STATS.get(dataset)) is None: + raise ValueError(f"Normalization for {dataset} not found") + config.normalization = normalization_config diff --git a/src/jmp/configs/pretrain/jmp_l.py b/src/jmp/configs/pretrain/jmp_l.py new file mode 100644 index 0000000..74b147a --- /dev/null +++ b/src/jmp/configs/pretrain/jmp_l.py @@ -0,0 +1,56 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from jmp.lightning import GradientClippingConfig + +from ...modules.dataset.concat_dataset import MTDatasetConfig +from ...modules.ema import EMAConfig +from ...tasks.config import AdamWConfig +from ...tasks.pretrain import PretrainConfig +from ...tasks.pretrain.module import LinearWarmupCosineAnnealingSchedulerConfig + + +def jmp_l_pt_config_(config: PretrainConfig): + # Set the model trainer settings for maximum performance + config.trainer.precision = "16-mixed" + config.trainer.set_float32_matmul_precision = "medium" + config.trainer.supports_parameter_hooks = False + config.trainer.supports_skip_batch_exception = False + + # Optimizer settings + config.optimizer = AdamWConfig( + lr=3.0e-4, + amsgrad=False, + betas=(0.9, 0.95), + weight_decay=0.1, + ) + config.trainer.optimizer.log_grad_norm = True + config.trainer.optimizer.gradient_clipping = GradientClippingConfig( + value=1.0, + algorithm="norm", + ) + # LR Scheduler settings + config.lr_scheduler = LinearWarmupCosineAnnealingSchedulerConfig( + warmup_steps=2000, + warmup_start_lr_factor=0.2, + min_lr_factor=0.1, + max_epochs=2, + ) + # Regularization settings + config.edge_dropout = 0.1 + # EMA settings + config.ema = EMAConfig(decay=0.99) + + # Set data config + config.num_workers = 8 + + # Set up the JMP MT dataset config and tasks + config.mt_dataset = MTDatasetConfig( + sample_type="temperature", + sample_temperature=2.0, + ) diff --git a/src/jmp/datasets/finetune/__init__.py b/src/jmp/datasets/finetune/__init__.py new file mode 100644 index 0000000..afcd705 --- /dev/null +++ b/src/jmp/datasets/finetune/__init__.py @@ -0,0 +1,11 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .base import LmdbDataset + +__all__ = ["LmdbDataset"] diff --git a/src/jmp/datasets/finetune/base.py b/src/jmp/datasets/finetune/base.py new file mode 100644 index 0000000..9367bfa --- /dev/null +++ b/src/jmp/datasets/finetune/base.py @@ -0,0 +1,244 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import bisect +import pickle +from collections.abc import Generator +from contextlib import ContextDecorator +from functools import cached_property +from logging import getLogger +from pathlib import Path +from typing import Any, TypedDict, cast + +import lmdb +import numpy as np +import torch +import torch_geometric +from torch.utils.data import Dataset +from torch_geometric.data import Data +from typing_extensions import TypeVar, override + +log = getLogger(__name__) + + +class Config(TypedDict): + src: str | Path + + +def _pyg2_data_transform(data: Data): + """ + if we're on the new pyg (2.0 or later) and if the Data stored is in older format + we need to convert the data to the new format + """ + if torch_geometric.__version__ >= "2.0" and "_store" not in data.__dict__: + return Data(**{k: v for k, v in data.__dict__.items() if v is not None}) + + return data + + +T = TypeVar("T", infer_variance=True) + + +class LmdbDataset(Dataset[T], ContextDecorator): + r"""Dataset class to load from LMDB files containing relaxation + trajectories or single point computations. + Useful for Structure to Energy & Force (S2EF), Initial State to + Relaxed State (IS2RS), and Initial State to Relaxed Energy (IS2RE) tasks. + The keys in the LMDB must be integers (stored as ascii objects) starting + from 0 through the length of the LMDB. For historical reasons any key named + "length" is ignored since that was used to infer length of many lmdbs in the same + folder, but lmdb lengths are now calculated directly from the number of keys. + Args: + config (dict): Dataset configuration + """ + + def data_sizes(self, indices: list[int]) -> np.ndarray: + return self.metadata["natoms"][indices] + + @cached_property + def metadata(self) -> dict[str, np.ndarray]: + metadata_path = self.metadata_path + if metadata_path and metadata_path.is_file(): + return np.load(metadata_path, allow_pickle=True) + + raise ValueError(f"Could not find atoms metadata in '{self.metadata_path}'") + + def data_transform(self, data: Data) -> Data: + return data + + def __init__( + self, + src: str | Path, + metadata_path: str | Path | None = None, + ) -> None: + super().__init__() + + self.path = Path(src) + if not self.path.is_file(): + db_paths = sorted(self.path.glob("*.lmdb")) + assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" + else: + assert self.path.suffix == ".lmdb", f"File '{self.path}' is not an LMDB" + db_paths = [self.path] + + self.metadata_path = ( + Path(metadata_path) if metadata_path else self.path / "metadata.npz" + ) + + self.keys: list[list[int]] = [] + self.envs: list[lmdb.Environment] = [] + # Open all the lmdb files + for db_path in db_paths: + cur_env = lmdb.open( + str(db_path.absolute()), + subdir=False, + readonly=True, + lock=False, + readahead=True, + meminit=False, + max_readers=1, + ) + self.envs.append(cur_env) + + # If "length" encoded as ascii is present, use that + length_entry = cur_env.begin().get("length".encode("ascii")) + if length_entry is not None: + num_entries = pickle.loads(length_entry) + else: + # Get the number of stores data from the number of entries + # in the LMDB + num_entries = cur_env.stat()["entries"] + + # Append the keys (0->num_entries) as a list + self.keys.append(list(range(num_entries))) + + keylens = [len(k) for k in self.keys] + self.keylen_cumulative: list[int] = np.cumsum(keylens).tolist() + self.num_samples = sum(keylens) + + def __len__(self) -> int: + return self.num_samples + + @override + def __getitem__(self, idx: int): + # Figure out which db this should be indexed from. + db_idx = bisect.bisect(self.keylen_cumulative, idx) + # Extract index of element within that db. + el_idx = idx + if db_idx != 0: + el_idx = idx - self.keylen_cumulative[db_idx - 1] + assert el_idx >= 0, f"{el_idx=} is not a valid index" + + # Return features. + key = f"{self.keys[db_idx][el_idx]}".encode("ascii") + env = self.envs[db_idx] + data_object_pickled = env.begin().get(key, default=None) + if data_object_pickled is None: + raise KeyError( + f"Key {key=} not found in {env=}. {el_idx=} {db_idx=} {idx=}" + ) + + data_object = _pyg2_data_transform(pickle.loads(cast(Any, data_object_pickled))) + data_object.id = f"{db_idx}_{el_idx}" + return data_object + + def close_db(self) -> None: + for env in self.envs: + env.close() + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.close_db() + + @classmethod + def pre_data_transform(cls, data: Data) -> Data: + if not hasattr(data, "tags"): + data.tags = torch.full_like(data.atomic_numbers, 2) + if not hasattr(data, "natoms"): + data.natoms = data.num_nodes + return data + + @classmethod + def save_indices( + cls, + split_indices: dict[str, np.ndarray], + root_path: Path, + file_name: str | None = None, + ): + # Dump the indices to root/metadata/split_indices.npz + if file_name is None: + split_indices_path = root_path / "metadata" + split_indices_path.mkdir(parents=True, exist_ok=True) + split_indices_path = split_indices_path / "split_indices.npz" + else: + split_indices_path = root_path / "metadata/split_indices" + split_indices_path.mkdir(parents=True, exist_ok=True) + split_indices_path = split_indices_path / file_name + + np.savez(split_indices_path, **split_indices) + split_indices_sizes = { + name: indices.shape[0] for name, indices in split_indices.items() + } + log.critical( + f"Dumped split_indices={split_indices_sizes} to {split_indices_path}" + ) + + @classmethod + def dump_data( + cls, + generator: Generator[Data, None, None], + count: int, + path: Path, + natoms_metadata_additional_path: Path | None = None, + num_per_file: int = 5_000, + ): + natoms_metadata_list: list[int] = [] + # Store each chunk as "data.%04d.lmdb" + for chunk_idx in range((count + num_per_file - 1) // num_per_file): + # Create a new lmdb file + chunk_path = path / f"data.{chunk_idx:04d}.lmdb" + cur_env = lmdb.open( + str(chunk_path.absolute()), + map_size=1099511627776 * 2, + subdir=False, + meminit=False, + map_async=True, + ) + + num_saved = 0 + with cur_env.begin(write=True) as txn: + # Save the number of entries in this lmdb + length = min(num_per_file, count - chunk_idx * num_per_file) + txn.put("length".encode("ascii"), pickle.dumps(length)) + + # Save the data + for data_idx in range(min(num_per_file, length)): + data = next(generator) + data = cls.pre_data_transform(data) + + # Get the natoms and save it for metadata + natoms_metadata_list.append(data.atomic_numbers.shape[0]) + + # Save the data + txn.put(f"{data_idx}".encode("ascii"), pickle.dumps(data)) + num_saved += 1 + + # Close the lmdb + cur_env.close() + log.critical(f"Saved {num_saved} entries to {chunk_path}") + + # Save the metadata + natoms_metadata = np.array(natoms_metadata_list) + for p in (natoms_metadata_additional_path, path / "metadata.npz"): + if p is None: + continue + p.parent.mkdir(parents=True, exist_ok=True) + np.savez(p, natoms=natoms_metadata) + log.critical(f"Saved metadata to {p}") diff --git a/src/jmp/datasets/finetune/mat_bench.py b/src/jmp/datasets/finetune/mat_bench.py new file mode 100644 index 0000000..0a14b75 --- /dev/null +++ b/src/jmp/datasets/finetune/mat_bench.py @@ -0,0 +1,181 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import logging +from logging import getLogger +from pathlib import Path +from typing import cast + +import numpy as np +import torch +from matbench.bench import MatbenchBenchmark, MatbenchTask +from pymatgen.core.structure import Structure +from torch_geometric.data import Data +from typing_extensions import TypeVar + +from .base import LmdbDataset +from .utils import env + +log = getLogger(__name__) + + +T = TypeVar("T", infer_variance=True) + + +class MatBench(LmdbDataset[T]): + tasks = [ + "matbench_jdft2d", + "matbench_phonons", + "matbench_dielectric", + "matbench_log_gvrh", + "matbench_log_kvrh", + "matbench_mp_is_metal", + "matbench_perovskites", + "matbench_mp_gap", + "matbench_mp_e_form", + ] + folds = [0, 1, 2, 3, 4] + + @classmethod + def download(cls, args: argparse.Namespace): + destination: str | Path = args.destination + random_seed: int = getattr(args, "random_seed", 42) + + global DOWNLOAD_URL, DOWNLOAD_FILENAME + + root_path = Path(destination) + root_path.mkdir(parents=True, exist_ok=True) + + # Make the raw data directory + raw_dir = root_path / "raw" + raw_dir.mkdir(parents=True, exist_ok=True) + + def _dump_split(data: tuple[list[Structure], list[float]], indices: np.ndarray): + inputs, outputs = data + # Make sure the sizes match + assert len(inputs) == len(outputs), f"{len(inputs)=} != {len(outputs)=}" + + # Dump the systems + for idx in indices: + idx = int(idx) + # Get the system and structure data + structure = inputs[idx] + output = outputs[idx] + + atomic_numbers = torch.tensor( + [site.specie.number for site in structure], dtype=torch.long + ) # natoms + cell = torch.tensor( + structure.lattice.matrix, dtype=torch.float + ).unsqueeze(dim=0) # 1 3 3 + pos = torch.tensor( + [site.coords for site in structure], dtype=torch.float + ) # natoms 3 + y = torch.tensor(output) # () + if isinstance(output, bool): + y = y.bool() + else: + y = y.float() + + data_object = Data( + atomic_numbers=atomic_numbers, + pos=pos, + cell=cell, + y=y, + ) + yield data_object + + with env({"MATMINER_DATA": str(raw_dir.absolute())}): + mb = MatbenchBenchmark(autoload=False, subset=cls.tasks) + mb.load() + for task in mb.tasks: + task = cast(MatbenchTask, task) + for fold in cls.folds: + train_val_data = task.get_train_and_val_data(fold, as_type="tuple") + assert isinstance( + train_val_data, tuple + ), f"{type(train_val_data)=} is not tuple" + test_data = task.get_test_data( + fold, as_type="tuple", include_target=True + ) + assert isinstance( + test_data, tuple + ), f"{type(test_data)=} is not tuple" + + # Get the train/val indices + num_train_val_indices = len(train_val_data[0]) + all_indices = np.arange(num_train_val_indices) + np.random.RandomState(random_seed).shuffle(all_indices) + num_indices = len(all_indices) + num_train = int(num_indices * 0.9) + num_val = num_indices - num_train + split_indices = { + "train": all_indices[:num_train], + "val": all_indices[num_train:], + } + # Test is a separate dataset, so we don't need to split it + test_indices = np.arange(len(test_data[0])) + split_indices["test"] = test_indices + + # Make sure the splits add up + assert ( + num_train + num_val == num_train_val_indices + ), f"{num_train=} + {num_val=} != {num_train_val_indices=}" + + # Dump the indices to root/metadata/split_indices.npz + cls.save_indices( + split_indices, + root_path, + f"{task.dataset_name}_{fold}.npz", + ) + + # Convert the raw data to LMDB + log.info("Converting raw data to LMDB") + + # Make the processed data directory + lmdb_path = root_path / "lmdb" / task.dataset_name / str(fold) + lmdb_path.mkdir(parents=True, exist_ok=True) + + # Dump the frames + for split, indices in split_indices.items(): + path = lmdb_path / split + path.mkdir(parents=True, exist_ok=True) + + data = train_val_data if split != "test" else test_data + + cls.dump_data( + _dump_split(data, indices), + count=indices.shape[0], + path=path, + natoms_metadata_additional_path=root_path + / "metadata" + / split + / f"{task.dataset_name}_{fold}.npz", + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + + # Add a subparser for the download command + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument( + "--destination", type=Path, required=True, help="Path to save the dataset" + ) + download_parser.add_argument( + "--random-seed", type=int, default=42, help="Random seed" + ) + download_parser.set_defaults(func=MatBench.download) + + args = parser.parse_args() + + args.func(args) diff --git a/src/jmp/datasets/finetune/md22.py b/src/jmp/datasets/finetune/md22.py new file mode 100644 index 0000000..6bb30ec --- /dev/null +++ b/src/jmp/datasets/finetune/md22.py @@ -0,0 +1,197 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import logging +import os +from logging import getLogger +from pathlib import Path +from typing import TypedDict + +import numpy as np +import torch +from torch_geometric.data import Data +from tqdm.auto import tqdm +from typing_extensions import TypeVar, override + +from ...modules.transforms import update_units_transform +from .base import LmdbDataset + +log = getLogger(__name__) + +T = TypeVar("T", infer_variance=True) + + +class MoleculeInfo(TypedDict): + train_size: int + + +DOWNLOAD_ROOT = "http://www.quantum-machine.org/gdml/repo/datasets/" +VAL_SIZE = 0.05 + + +class MD22(LmdbDataset[T]): + molecules: dict[str, MoleculeInfo] = { + "Ac-Ala3-NHMe": {"train_size": 6000}, + "DHA": {"train_size": 8000}, + "stachyose": {"train_size": 8000}, + "AT-AT": {"train_size": 3000}, + "AT-AT-CG-CG": {"train_size": 2000}, + "buckyball-catcher": {"train_size": 600}, + "double-walled_nanotube": {"train_size": 800}, + } + + @classmethod + def download(cls, args: argparse.Namespace, molecule: str): + destination: str | Path = args.destination + random_seed: int = getattr(args, "random_seed", 42) + + global DOWNLOAD_ROOT, VAL_SIZE + + if (info := cls.molecules.get(molecule)) is None: + raise ValueError(f"{molecule=} is not a valid MD22 molecule name.") + + # Create root directory + root_path = Path(destination) + root_path.mkdir(parents=True, exist_ok=True) + + # Download dataset + dl_path = root_path / "raw/" + dl_path.mkdir(parents=True, exist_ok=True) + npz_file = dl_path / f"md22_{molecule}.npz" + if not npz_file.exists(): + log.info(f"Downloading {npz_file}") + _ = os.system(f"wget -q {DOWNLOAD_ROOT}md22_{molecule}.npz -P {dl_path}") + log.info(f"Downloaded {npz_file}") + + # Load data + """ + NPZ file data for Ac-Ala3-NHMe: + {'E': array[85109, 1] x∈[-6.207e+05, -6.206e+05] μ=-6.207e+05 σ=8.204, + 'E_max': array(-620623.81595481), + 'E_mean': array(-620662.71173186), + 'E_min': array(-620726.00266174), + 'E_var': array(67.30187507), + 'F': array[85109, 42, 3] n=10723734 x∈[-221.883, 216.102] μ=-1.702e-09 σ=26.039, + 'F_max': array(216.10170499), + 'F_mean': array(-1.70237435e-09), + 'F_min': array(-221.88345657), + 'F_var': array(678.01604927), + 'R': array[85109, 42, 3] n=10723734 x∈[-7.323, 7.873] μ=0.008 σ=2.187, + 'code_version': array('0.4.18.dev1', dtype=' Data: + data = super().pre_data_transform(data) + data = update_units_transform(data, ["y", "force"], from_="kcal/mol", to="eV") + return data + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + + # Add a subparser for the download command + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument( + "--destination", type=Path, required=True, help="Path to save the dataset" + ) + download_parser.add_argument( + "--random-seed", type=int, default=42, help="Random seed" + ) + download_parser.set_defaults(func=MD22.download) + + args = parser.parse_args() + + pbar = tqdm(MD22.molecules.keys()) + for molecule in pbar: + pbar.set_description(molecule) + args.func(args, molecule=molecule) diff --git a/src/jmp/datasets/finetune/qm9.py b/src/jmp/datasets/finetune/qm9.py new file mode 100644 index 0000000..6231752 --- /dev/null +++ b/src/jmp/datasets/finetune/qm9.py @@ -0,0 +1,192 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import logging +import os +from logging import getLogger +from pathlib import Path +from typing import TypedDict, cast + +import numpy as np +import torch +from torch_geometric.data import Data +from typing_extensions import TypeVar + +from .base import LmdbDataset + +log = getLogger(__name__) + +T = TypeVar("T", infer_variance=True) + + +class QM9Item(TypedDict): + idx: torch.Tensor # (1,) + name: str + z: torch.Tensor # (n_atoms,) + x: torch.Tensor # (n_atoms, n_features) + pos: torch.Tensor # (n_atoms, 3) + edge_index: torch.Tensor # (2, n_edges) + edge_attr: torch.Tensor # (n_edges, n_edge_features) + y: torch.Tensor # (1, n_targets) + + +DOWNLOAD_URL = "https://data.pyg.org/datasets/qm9_v3.zip" +DOWNLOAD_FILENAME = "qm9_v3.zip" +NUM_VAL = 10_000 +NUM_TEST = 10_000 + + +class QM9(LmdbDataset[T]): + targets = [ + "mu", # dipole_moment + "alpha", # isotropic_polarizability + "eps_HOMO", # homo + "eps_LUMO", # lumo + "delta_eps", # homo_lumo_gap + "R_2_Abs", # electronic_spatial_extent + "ZPVE", # zpve + "U_0", # energy_U0 + "U", # energy_U + "H", # enthalpy_H + "G", # free_energy + "c_v", # heat_capacity + "U_0_ATOM", # atomization_energy_U0 + "U_ATOM", # atomization_energy_U + "H_ATOM", # atomization_enthalpy_H + "G_ATOM", # atomization_free_energy + "A", # rotational_constant_A + "B", # rotational_constant_B + "C", # rotational_constant_C + ] + + @classmethod + def download(cls, args: argparse.Namespace): + destination: str | Path = args.destination + random_seed: int = getattr(args, "random_seed", 42) + + global DOWNLOAD_URL, DOWNLOAD_FILENAME + + root_path = Path(destination) + root_path.mkdir(parents=True, exist_ok=True) + + # Make the raw data directory + raw_dir = root_path / "raw" + raw_dir.mkdir(parents=True, exist_ok=True) + + # Download the raw data + raw_file = raw_dir / DOWNLOAD_FILENAME + if not raw_file.exists(): + log.info("Downloading raw data") + _ = os.system(f"wget -q -O {raw_file} {DOWNLOAD_URL}") + + # Unzip the raw data + log.info("Unzipping raw data") + _ = os.system(f"unzip {raw_file} -d {raw_dir}") + else: + log.info("Raw data already downloaded") + + # Load the raw data + data = torch.load(raw_dir / "qm9_v3.pt") + assert isinstance(data, list), f"{type(data)=} is not list" + data = cast(list[QM9Item], data) + + # Create the splits + num_indices = len(data) + num_val = NUM_VAL + num_test = NUM_TEST + num_train = num_indices - num_val - num_test + + # Get the indices for each split (80/10/10 train/val/test) + all_indices = np.arange(num_indices) + np.random.RandomState(random_seed).shuffle(all_indices) + + split_indices = { + "train": all_indices[:num_train], + "val": all_indices[num_train : num_train + num_val], + "test": all_indices[num_train + num_val :], + } + + # Make sure the splits add up + assert ( + len(split_indices["train"]) + + len(split_indices["val"]) + + len(split_indices["test"]) + == num_indices + ), "Indices don't add up" + split_sizes = { + name: indices.shape[0] for name, indices in split_indices.items() + } + log.critical(f"Created the following splits: {split_sizes}") + + # Dump the indices to root/metadata/split_indices.npz + cls.save_indices(split_indices, root_path) + + def _dump_idx(indices: np.ndarray): + nonlocal data + + # Dump the systems + for idx in indices: + idx = int(idx) + # Get the system and structure data + system = data[idx] + + atomic_numbers = system["z"].clone().long() # natoms + pos = system["pos"].clone().float() # natoms 3 + y = { + target: system["y"][0][i] + .clone() + .float() # "y" is a tensor of shape (1, num_targets) + for i, target in enumerate(cls.targets) + } + + data_object = Data(atomic_numbers=atomic_numbers, pos=pos, **y) + yield data_object + + # Convert the raw data to LMDB + log.info("Converting raw data to LMDB") + + # Make the processed data directory + lmdb_path = root_path / "lmdb" + lmdb_path.mkdir(parents=True, exist_ok=True) + + # Dump the frames + for split, indices in split_indices.items(): + path = lmdb_path / split + path.mkdir(parents=True, exist_ok=True) + + cls.dump_data( + _dump_idx(indices), + count=indices.shape[0], + path=path, + natoms_metadata_additional_path=root_path + / "metadata" + / split + / "metadata.npz", + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + + # Add a subparser for the download command + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument( + "--destination", type=Path, required=True, help="Path to save the dataset" + ) + download_parser.add_argument( + "--random-seed", type=int, default=42, help="Random seed" + ) + download_parser.set_defaults(func=QM9.download) + + args = parser.parse_args() + + args.func(args) diff --git a/src/jmp/datasets/finetune/qmof.py b/src/jmp/datasets/finetune/qmof.py new file mode 100644 index 0000000..22f230d --- /dev/null +++ b/src/jmp/datasets/finetune/qmof.py @@ -0,0 +1,319 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import json +import logging +import os +from logging import getLogger +from pathlib import Path +from typing import TypedDict, cast + +import numpy as np +import torch +from torch_geometric.data import Data +from typing_extensions import TypeVar + +from .base import LmdbDataset +from .utils import atomic_symbol_to_element + +log = getLogger(__name__) + +T = TypeVar("T", infer_variance=True) + + +# region Types +# region Structure +class Lattice(TypedDict): + matrix: list[list[float]] + a: float + b: float + c: float + alpha: float + beta: float + gamma: float + volume: float + + +class Properties(TypedDict): + pbe_ddec_sum_bond_order: float + pbe_ddec_charge: float + pbe_cm5_charge: float + pbe_bader_charge: float + pbe_magmom: float + pbe_ddec_spin_density: float + pbe_bader_spin_density: float + + +class Species(TypedDict): + element: str + occu: int + + +class Site(TypedDict): + species: list[Species] + abc: list[float] + xyz: list[float] + label: str + properties: Properties + + +class Structure(TypedDict): + module: str + structure_class: str + charge: int + lattice: Lattice + sites: list[Site] + + +class QMOFStructureDict(TypedDict): + qmof_id: str + name: str + structure: Structure + + +# endregion + +# region System + + +class Mofid(TypedDict): + mofid: None + mofkey: None + smiles_nodes: list[str] + smiles_linkers: list[str] + smiles: str + topology: None + + +class Symmetry(TypedDict): + spacegroup: str + spacegroupnumber: int + spacegroupcrystal: str + pointgroup: int + + +class Info(TypedDict): + formula: str + formula_reduced: str + mofid: Mofid + natoms: int + pld: float + lcd: float + density: float + volume: float + symmetry: Symmetry + synthesized: bool + source: str + doi: str + + +class InputsPbe(TypedDict): + theory: str + pseudopotentials: list[str] + encut: int + kpoints: list[int] + gamma: bool + spin: bool + + +class Inputs(TypedDict): + pbe: InputsPbe + + +class OutputsPbe(TypedDict): + energy_total: float + energy_vdw: float + energy_elec: float + net_magmom: int + bandgap: float + cbm: float + vbm: float + directgap: bool + bandgap_spins: list[float] + cbm_spins: list[float] + vbm_spins: list[float] + directgap_spins: list[bool] + + +class Outputs(TypedDict): + pbe: OutputsPbe + + +class QMOFSystemDict(TypedDict): + qmof_id: str + name: str + info: Info + inputs: Inputs + outputs: Outputs + + +# endregion +# endregion + +DOWNLOAD_URL = "https://figshare.com/ndownloader/articles/13147324/versions/14" +DOWNLOAD_FILENAME = "13147324.zip" + + +class QMOF(LmdbDataset[T]): + @classmethod + def download(cls, args: argparse.Namespace): + destination: str | Path = args.destination + random_seed: int = getattr(args, "random_seed", 42) + + global DOWNLOAD_URL, DOWNLOAD_FILENAME + + root_path = Path(destination) + root_path.mkdir(parents=True, exist_ok=True) + + # Make the raw data directory + raw_dir = root_path / "raw" + raw_dir.mkdir(parents=True, exist_ok=True) + + # Download the raw data + raw_file = raw_dir / DOWNLOAD_FILENAME + if not raw_file.exists(): + log.info("Downloading raw data") + _ = os.system(f"wget -q -O {raw_file} {DOWNLOAD_URL}") + + # Unzip the raw data + log.info("Unzipping raw data") + _ = os.system(f"unzip {raw_file} -d {raw_dir}") + + # Unzip the "qmof_database.zip" file + log.info("Unzipping qmof_database.zip") + _ = os.system(f"unzip {raw_dir / 'qmof_database.zip'} -d {raw_dir}") + else: + log.info("Raw data already downloaded") + + # Load the raw data + log.info("Loading raw data") + with open(raw_dir / "qmof_database" / "qmof.json") as f: + qmof_data = json.load(f) + + with open(raw_dir / "qmof_database" / "qmof_structure_data.json") as f: + qmof_structure_data = json.load(f) + + # Both of these are lists of dicts and should be the same length + assert isinstance(qmof_data, list), f"{type(qmof_data)=} is not list" + assert isinstance( + qmof_structure_data, list + ), f"{type(qmof_structure_data)=} is not list" + assert len(qmof_data) == len( + qmof_structure_data + ), f"{len(qmof_data)=} != {len(qmof_structure_data)=}" + qmof_data = cast(list[QMOFSystemDict], qmof_data) + qmof_structure_data = cast(list[QMOFStructureDict], qmof_structure_data) + + # Get the indices for each split (80/10/10 train/val/test) + all_indices = np.arange(len(qmof_data)) + np.random.RandomState(random_seed).shuffle(all_indices) + num_indices = len(all_indices) + num_train = int(num_indices * 0.8) + num_test = int(num_indices * 0.1) + num_val = num_indices - num_train - num_test + split_indices = { + "train": all_indices[:num_train], + "val": all_indices[num_train : num_train + num_val], + "test": all_indices[num_train + num_val :], + } + + # Make sure the splits add up + assert ( + num_train + num_val + num_test == num_indices + ), f"{num_train=} + {num_val=} + {num_test=} != {num_indices=}" + + # Dump the indices to root/metadata/split_indices.npz + cls.save_indices(split_indices, root_path) + + def _dump_idx(indices: np.ndarray): + nonlocal qmof_data, qmof_structure_data + + # Dump the systems + for idx in indices: + idx = int(idx) + # Get the system and structure data + system = qmof_data[idx] + structure = qmof_structure_data[idx] + + atomic_numbers = torch.tensor( + [ + atomic_symbol_to_element[site["label"]] + for site in structure["structure"]["sites"] + ], + dtype=torch.long, + ) # natoms + pos = torch.tensor( + [site["xyz"] for site in structure["structure"]["sites"]], + dtype=torch.float, + ) # natoms 3 + cell = torch.tensor( + structure["structure"]["lattice"]["matrix"], + dtype=torch.float, + ).unsqueeze(dim=0) # 1 3 3 + band_gap = torch.tensor( + system["outputs"]["pbe"]["bandgap"], + dtype=torch.float, + ) # () + energy_total = torch.tensor( + system["outputs"]["pbe"]["energy_total"], + dtype=torch.float, + ) # () + + data_object = Data( + atomic_numbers=atomic_numbers, + pos=pos, + cell=cell, + y=band_gap, + energy_total=energy_total, + sid=system["qmof_id"], + ) + yield data_object + + # Convert the raw data to LMDB + log.info("Converting raw data to LMDB") + + # Make the processed data directory + lmdb_path = root_path / "lmdb" + lmdb_path.mkdir(parents=True, exist_ok=True) + + # Dump the frames + for split, indices in split_indices.items(): + path = lmdb_path / split + path.mkdir(parents=True, exist_ok=True) + + cls.dump_data( + _dump_idx(indices), + count=indices.shape[0], + path=path, + natoms_metadata_additional_path=root_path + / "metadata" + / split + / "metadata.npz", + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + + # Add a subparser for the download command + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument( + "--destination", type=Path, required=True, help="Path to save the dataset" + ) + download_parser.add_argument( + "--random-seed", type=int, default=42, help="Random seed" + ) + download_parser.set_defaults(func=QMOF.download) + + args = parser.parse_args() + + args.func(args) diff --git a/src/jmp/datasets/finetune/rmd17.py b/src/jmp/datasets/finetune/rmd17.py new file mode 100644 index 0000000..fec81bf --- /dev/null +++ b/src/jmp/datasets/finetune/rmd17.py @@ -0,0 +1,192 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import logging +import os +from logging import getLogger +from pathlib import Path + +import numpy as np +import torch +from torch_geometric.data import Data +from tqdm import tqdm +from typing_extensions import TypeVar, override + +from ...modules.transforms.units import update_units_transform +from .base import LmdbDataset + +log = getLogger(__name__) + +T = TypeVar("T", infer_variance=True) + +DOWNLOAD_URL = "https://figshare.com/ndownloader/articles/12672038/versions/3" +DOWNLOAD_FILENAME = "12672038.zip" +TRAIN_SIZE = 1000 +VAL_SIZE = 50 + + +class MD17(LmdbDataset[T]): + name_mappings = { + "napthalene": "naphthalene", + "salicylic acid": "salicylic", + "salicylic-acid": "salicylic", + "salicylic_acid": "salicylic", + } + molecules = { + "aspirin", + "azobenzene", + "benzene", + "ethanol", + "malonaldehyde", + "naphthalene", + "paracetamol", + "salicylic", + "toluene", + "uracil", + } + + @classmethod + def download(cls, args: argparse.Namespace, molecule: str): + destination: str | Path = args.destination + random_seed: int = getattr(args, "random_seed", 42) + + global DOWNLOAD_URL, DOWNLOAD_FILENAME, TRAIN_SIZE, VAL_SIZE + molecule = cls.name_mappings.get(molecule, molecule) + + if molecule not in cls.molecules: + raise ValueError(f"Invalid molecule: {molecule}") + + root_path = Path(destination) + root_path.mkdir(parents=True, exist_ok=True) + + # Make the raw data directory + raw_dir = root_path / "raw" + raw_dir.mkdir(parents=True, exist_ok=True) + + # Download the raw data + raw_file = raw_dir / DOWNLOAD_FILENAME + if not raw_file.exists(): + log.info("Downloading raw data") + _ = os.system(f"wget -q -O {raw_file} {DOWNLOAD_URL}") + + # Unzip the raw data + log.info("Unzipping raw data") + _ = os.system(f"unzip {raw_file} -d {raw_dir}") + + # Untar the rmd17 data + log.info("Untarring rmd17 data") + _ = os.system(f"tar -xjf {raw_dir / 'rmd17.tar.bz2'} -C {raw_dir}") + else: + log.info("Raw data already downloaded") + + # Load the raw data + path = raw_dir / f"rmd17/npz_data/rmd17_{molecule}.npz" + log.info(f"Loading raw data from {path}") + data = np.load(path) + + atomic_numbers = torch.from_numpy(data["nuclear_charges"]).long() # natoms + pos = torch.from_numpy(data["coords"]).float() # nframes natoms 3 + y = torch.from_numpy(data["energies"]).float() # nframes + force = torch.from_numpy(data["forces"]).float() # nframes natoms 3 + + n_frames, _, _ = pos.shape + + # Shuffle the frames + frame_indices = np.arange(n_frames) + np.random.RandomState(random_seed).shuffle(frame_indices) + + # Split into train/test based on train_size + train_size = TRAIN_SIZE + test_size = n_frames - train_size + + # Split train into train/val using 95/5 split + val_size = VAL_SIZE + train_size -= VAL_SIZE + + # Make sure the frames add up + assert train_size + val_size + test_size == n_frames, "Frames don't add up" + + # Create train/val/test splits + split_indices = { + "train": frame_indices[:train_size], + "val": frame_indices[train_size : train_size + val_size], + "test": frame_indices[train_size + val_size :], + } + split_sizes = { + name: indices.shape[0] for name, indices in split_indices.items() + } + log.critical(f"Created the following splits: {split_sizes}") + + # Dump the indices to root/metadata/split_indices.npz + cls.save_indices(split_indices, root_path, f"{molecule}.npz") + + def _dump_frames(indices: np.ndarray): + nonlocal atomic_numbers, pos, y, force + + # Dump the frames + for frame_idx in indices: + data_object = Data( + atomic_numbers=atomic_numbers.clone(), + pos=pos[frame_idx].clone(), + y=y[frame_idx].clone(), + force=force[frame_idx].clone(), + sid=torch.tensor(frame_idx, dtype=torch.long), + ) + yield data_object + + # Make lmdb directory + lmdb_path = root_path / "lmdb" / molecule + lmdb_path.mkdir(parents=True, exist_ok=True) + + # Dump the frames + log.info("Converting raw data to LMDB") + for split, indices in split_indices.items(): + path = lmdb_path / split + path.mkdir(parents=True, exist_ok=True) + + cls.dump_data( + _dump_frames(indices), + count=indices.shape[0], + path=path, + natoms_metadata_additional_path=root_path + / "metadata" + / split + / f"{molecule}.npz", + ) + + @override + @classmethod + def pre_data_transform(cls, data: Data) -> Data: + data = super().pre_data_transform(data) + data = update_units_transform(data, ["y", "force"], from_="kcal/mol", to="eV") + return data + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + + # Add a subparser for the download command + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument( + "--destination", type=Path, required=True, help="Path to save the dataset" + ) + download_parser.add_argument( + "--random-seed", type=int, default=42, help="Random seed" + ) + download_parser.set_defaults(func=MD17.download) + + args = parser.parse_args() + + pbar = tqdm(list(MD17.molecules)) + for molecule in pbar: + pbar.set_description(molecule) + args.func(args, molecule=molecule) diff --git a/src/jmp/datasets/finetune/spice.py b/src/jmp/datasets/finetune/spice.py new file mode 100644 index 0000000..c64d570 --- /dev/null +++ b/src/jmp/datasets/finetune/spice.py @@ -0,0 +1,219 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import logging +import os +from logging import getLogger +from pathlib import Path +from typing import Any, TypedDict + +import h5py +import numpy as np +import torch +from torch_geometric.data import Data +from tqdm import tqdm +from typing_extensions import TypeVar, override + +from ...modules.transforms.units import update_units_transform +from .base import LmdbDataset + +log = getLogger(__name__) + +T = TypeVar("T", infer_variance=True) + +TRAIN_SIZE = 0.8 +TEST_SIZE = 0.1 + + +def _assert_dataset(dataset: Any) -> h5py.Dataset: + if not isinstance(dataset, h5py.Dataset): + raise ValueError(f"{dataset=} is not a dataset") + + return dataset + + +class DatasetInfo(TypedDict): + hdf5: str + + +class SPICE(LmdbDataset[T]): + datasets: dict[str, DatasetInfo] = { + "dipeptides": {"hdf5": "https://fm-datasets.s3.amazonaws.com/dipeptides.h5"}, + "solvated_amino_acids": { + "hdf5": "https://fm-datasets.s3.amazonaws.com/solvated_amino_acids.h5" + }, + } + + @classmethod + def download(cls, args: argparse.Namespace, dataset: str): + destination: str | Path = args.destination + random_seed: int = getattr(args, "random_seed", 42) + + global TRAIN_SIZE, TEST_SIZE + + if (info := cls.datasets.get(dataset)) is None: + raise ValueError(f"{dataset=} is not a valid SPICE dataset name.") + + # Create root directory + root_path = Path(destination) + root_path.mkdir(parents=True, exist_ok=True) + + # Download dataset + dl_path = root_path / "raw/" + dl_path.mkdir(parents=True, exist_ok=True) + + # Download HDF5 file + h5_path = dl_path / info["hdf5"].split("/")[-1] + if not h5_path.exists(): + log.info(f"Downloading {info['hdf5']} to {h5_path}") + _ = os.system(f"wget -q {info['hdf5']} -O {h5_path}") + else: + log.info(f"Found {h5_path}") + + # Load the h5 file + with h5py.File(h5_path, "r") as f: + # First, flatten the dataset to get a list of all conformers + all_conformers: list[tuple[str, int]] = [] + for mol_idx, mol in f.items(): + if not isinstance(mol, h5py.Group): + continue + conformations = mol["conformations"] + if not isinstance(conformations, h5py.Dataset): + continue + for conf_idx in range(conformations.shape[0]): + all_conformers.append((mol_idx, conf_idx)) + + # Get the indices for each split (80/10/10 train/val/test) + all_indices = np.arange(len(all_conformers)) + np.random.RandomState(random_seed).shuffle(all_indices) + num_indices = len(all_indices) + num_train = int(num_indices * TRAIN_SIZE) + num_test = int(num_indices * TEST_SIZE) + num_val = num_indices - num_train - num_test + + split_indices = { + "train": all_indices[:num_train], + "val": all_indices[num_train : num_train + num_val], + "test": all_indices[num_train + num_val :], + } + + # Make sure the splits add up + assert ( + num_train + num_val + num_test == num_indices + ), f"{num_train=} + {num_val=} + {num_test=} != {num_indices=}" + + # Dump the indices to root/metadata/split_indices.npz + cls.save_indices(split_indices, root_path, dataset) + + def _dump_idx(indices: np.ndarray): + nonlocal all_conformers, f + + # Dump the systems + for idx in indices: + idx = int(idx) + + # Get the molecule and conformer indices + mol_idx, conf_idx = all_conformers[idx] + + # Get the molecule + mol = f[mol_idx] + assert isinstance(mol, h5py.Group), f"{mol=} is not a group" + + atomic_numbers = atomic_numbers = torch.from_numpy( + np.array(_assert_dataset(mol.get("atomic_numbers"))) + ).long() # n_atoms + y = torch.from_numpy( + np.array(_assert_dataset(mol.get("dft_total_energy"))[conf_idx]) + ).float() # () + formation_energy = torch.from_numpy( + np.array(_assert_dataset(mol.get("formation_energy"))[conf_idx]) + ).float() # () + force = torch.from_numpy( + np.array( + _assert_dataset(mol.get("dft_total_gradient"))[conf_idx] + ) + ).float() # n_atoms 3 + pos = torch.from_numpy( + np.array(_assert_dataset(mol.get("conformations"))[conf_idx]) + ).float() # n_atoms 3 + + data_object = Data( + atomic_numbers=atomic_numbers, + pos=pos, + force=force, + y=y, + formation_energy=formation_energy, + sid=f"{mol_idx}__{conf_idx}", + ) + yield data_object + + # Convert the raw data to LMDB + log.info("Converting raw data to LMDB") + + # Make the processed data directory + lmdb_path = root_path / "lmdb" / dataset + lmdb_path.mkdir(parents=True, exist_ok=True) + + # Dump the frames + for split, indices in split_indices.items(): + path = lmdb_path / split + path.mkdir(parents=True, exist_ok=True) + + cls.dump_data( + _dump_idx(indices), + count=indices.shape[0], + path=path, + natoms_metadata_additional_path=root_path + / "metadata" + / split + / f"{dataset}.npz", + ) + + @override + @classmethod + def pre_data_transform(cls, data: Data) -> Data: + data = super().pre_data_transform(data) + data = update_units_transform( + data, ["y", "force", "formation_energy"], from_="hartree", to="eV" + ) + data = update_units_transform( + data, attributes=["pos"], from_="bohr", to="angstrom" + ) + data = update_units_transform( + data, + attributes=["force"], + from_="bohr", + to="angstrom", + reciprocal=True, + ) + return data + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + + # Add a subparser for the download command + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument( + "--destination", type=Path, required=True, help="Path to save the dataset" + ) + download_parser.add_argument( + "--random-seed", type=int, default=42, help="Random seed" + ) + download_parser.set_defaults(func=SPICE.download) + + args = parser.parse_args() + + pbar = tqdm(list(SPICE.datasets)) + for dataset in pbar: + pbar.set_description(dataset) + args.func(args, dataset=dataset) diff --git a/src/jmp/datasets/finetune/utils.py b/src/jmp/datasets/finetune/utils.py new file mode 100644 index 0000000..4063992 --- /dev/null +++ b/src/jmp/datasets/finetune/utils.py @@ -0,0 +1,152 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +from contextlib import contextmanager + +atomic_symbol_to_element = { + "n": 0, + "H": 1, + "He": 2, + "Li": 3, + "Be": 4, + "B": 5, + "C": 6, + "N": 7, + "O": 8, + "F": 9, + "Ne": 10, + "Na": 11, + "Mg": 12, + "Al": 13, + "Si": 14, + "P": 15, + "S": 16, + "Cl": 17, + "Ar": 18, + "K": 19, + "Ca": 20, + "Sc": 21, + "Ti": 22, + "V": 23, + "Cr": 24, + "Mn": 25, + "Fe": 26, + "Co": 27, + "Ni": 28, + "Cu": 29, + "Zn": 30, + "Ga": 31, + "Ge": 32, + "As": 33, + "Se": 34, + "Br": 35, + "Kr": 36, + "Rb": 37, + "Sr": 38, + "Y": 39, + "Zr": 40, + "Nb": 41, + "Mo": 42, + "Tc": 43, + "Ru": 44, + "Rh": 45, + "Pd": 46, + "Ag": 47, + "Cd": 48, + "In": 49, + "Sn": 50, + "Sb": 51, + "Te": 52, + "I": 53, + "Xe": 54, + "Cs": 55, + "Ba": 56, + "La": 57, + "Ce": 58, + "Pr": 59, + "Nd": 60, + "Pm": 61, + "Sm": 62, + "Eu": 63, + "Gd": 64, + "Tb": 65, + "Dy": 66, + "Ho": 67, + "Er": 68, + "Tm": 69, + "Yb": 70, + "Lu": 71, + "Hf": 72, + "Ta": 73, + "W": 74, + "Re": 75, + "Os": 76, + "Ir": 77, + "Pt": 78, + "Au": 79, + "Hg": 80, + "Tl": 81, + "Pb": 82, + "Bi": 83, + "Po": 84, + "At": 85, + "Rn": 86, + "Fr": 87, + "Ra": 88, + "Ac": 89, + "Th": 90, + "Pa": 91, + "U": 92, + "Np": 93, + "Pu": 94, + "Am": 95, + "Cm": 96, + "Bk": 97, + "Cf": 98, + "Es": 99, + "Fm": 100, + "Md": 101, + "No": 102, + "Lr": 103, + "Rf": 104, + "Db": 105, + "Sg": 106, + "Bh": 107, + "Hs": 108, + "Mt": 109, + "Ds": 110, + "Rg": 111, + "Cn": 112, + "Nh": 113, + "Fl": 114, + "Mc": 115, + "Lv": 116, + "Ts": 117, + "Og": 118, +} + + +@contextmanager +def env(env_vars: dict[str, str]): + """Context manager for setting environment variables""" + # Get the old environment variables for the keys in kwargs + old_values = {key: os.environ.get(key) for key in env_vars} + + # Set the new environment variables + os.environ.update(env_vars) + + try: + yield + finally: + # Restore the old environment variables + for key, value in old_values.items(): + if value is None: + del os.environ[key] + else: + os.environ[key] = value diff --git a/src/jmp/datasets/finetune_pdbbind.py b/src/jmp/datasets/finetune_pdbbind.py new file mode 100644 index 0000000..e10473d --- /dev/null +++ b/src/jmp/datasets/finetune_pdbbind.py @@ -0,0 +1,266 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import lru_cache +from logging import getLogger +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast + +import numpy as np +import torch +from Bio import PDB +from Bio.PDB.PDBExceptions import PDBConstructionWarning +from einops import pack +from jmp.lightning import TypedConfig +from rdkit import Chem +from torch.utils.data import Dataset +from torch_geometric.data.data import BaseData, Data +from typing_extensions import override + +from ..modules import transforms as T + +try: + # Ignore logging warnings from deepchem (of the kind `logging.warning("Skipped loading some...")`) + getLogger("deepchem.models").setLevel("ERROR") + getLogger("deepchem.metalearning").setLevel("ERROR") + getLogger("deepchem.utils").setLevel("ERROR") + # "No normalization for ..." + getLogger("deepchem.feat.molecule_featurizers.rdkit_descriptors").setLevel("ERROR") + + import deepchem.feat as feat + import deepchem.molnet as molnet + import deepchem.splits as splits + from deepchem.data import Dataset as DCDataset + from deepchem.trans import NormalizationTransformer, Transformer + + molnet_err: ImportError | None = None +except ImportError as err: + if TYPE_CHECKING: + raise + molnet_err = err + DCDataset = Any + + +log = getLogger(__name__) + + +class _MolNetDatasetBase(Dataset[BaseData], ABC): + split_idx = { + "train": 0, + "val": 1, + "test": 2, + } + + @override + def __init__( + self, + dataset_output: Callable[ + [], tuple[list[str], tuple[DCDataset, ...], list[Transformer]] + ], + task: str, + split: Literal["train", "val", "test"], + linref_coeffs: torch.Tensor | None = None, + transform: Callable[[BaseData], BaseData] | None = None, + ): + super().__init__() + + if molnet_err is not None: + raise ImportError("`deepchem` library is not available") from molnet_err + + tasks, datasets, transformers = dataset_output() + + self.task_idx = tasks.index(task) + split_idx = self.split_idx[split] + + dataset = datasets[split_idx] + + self.y_mean = torch.tensor(0.0) + self.y_std = torch.tensor(1.0) + if ( + transformer := next( + (t for t in transformers if isinstance(t, NormalizationTransformer)), + None, + ) + ) is not None and transformer.transform_y: + # If the transformer is a NormalizationTransformer, we keep + # the coefficients for later use. + y_means = transformer.y_means + y_stds = transformer.y_stds + + # If y_means and y_stds are not arrays, then self.task_idx must be 0 + if isinstance(y_means, float) or ( + isinstance(y_means, np.ndarray) and y_means.size == 1 + ): + assert self.task_idx == 0, "Only one task is supported" + y_means = [y_means] + + if isinstance(y_stds, float) or ( + isinstance(y_stds, np.ndarray) and y_stds.size == 1 + ): + assert self.task_idx == 0, "Only one task is supported" + y_stds = [y_stds] + + self.y_mean = torch.tensor(y_means[self.task_idx], dtype=torch.float) + self.y_std = torch.tensor(y_stds[self.task_idx], dtype=torch.float) + + y = dataset.y + # If y is a 1D array, then self.task_idx must be 0 + if y.ndim == 1: + assert self.task_idx == 0, "Only one task is supported" + y = y[:, None] + + self.X = cast(list[Chem.rdchem.Mol], dataset.X) + self.y = y[:, self.task_idx] + self.linref_coeffs = linref_coeffs + self.transform = transform + + def __len__(self) -> int: + return len(self.X) + + @abstractmethod + def X_to_data(self, idx: int) -> BaseData: ... + + @override + def __getitem__(self, idx: int) -> BaseData: + data = self.X_to_data(idx) + # If the number of atoms is less than 4, the molecule is not valid + if data.atomic_numbers.shape[0] < 4: + # HACK: Just return the next molecule + return self.__getitem__((idx + 1) % len(self)) + + if self.linref_coeffs is not None: + data = T.atomref_transform(data, {"y": self.linref_coeffs}) + + if (transform := self.transform) is not None: + data = transform(data) + + return data + + +PDBBindTask: TypeAlias = Literal["-logKd/Ki"] + + +_pt = Chem.GetPeriodicTable() + + +class PDBBindDataset(_MolNetDatasetBase): + @override + def __init__( + self, + task: PDBBindTask, + split: Literal["train", "val", "test"], + linref_coeffs: torch.Tensor | None = None, + transform: Callable[[BaseData], BaseData] | None = None, + ): + warnings.filterwarnings("ignore", category=PDBConstructionWarning) + + super().__init__( + lambda: molnet.load_pdbbind( + featurizer=feat.RawFeaturizer(), + # Random splitting is recommended for this dataset. + # See https://deepchem.readthedocs.io/en/latest/api_reference/moleculenet.html#pdbbind-datasets + splitter=splits.RandomSplitter(), + ), + task, + split, + linref_coeffs, + transform, + ) + + @override + def X_to_data(self, idx: int): + ligand_sdf_file, pocket_pdb_file = cast(tuple[str, str], self.X[idx]) + if (ligand_sdf_info := self._get_sdf_info(ligand_sdf_file)) is None: + raise RuntimeError( + f"Failed to extract info from {ligand_sdf_file=} for {idx=}" + ) + if (pocket_pdb_info := self._get_pdb_info(pocket_pdb_file)) is None: + raise RuntimeError( + f"Failed to extract info from {pocket_pdb_file=} for {idx=}" + ) + + ligand_atomic_numbers, ligand_pos = ligand_sdf_info + ligand_atomic_numbers = torch.from_numpy( + ligand_atomic_numbers + ).long() # n_ligand_atoms + ligand_pos = torch.from_numpy(ligand_pos).float() # n_ligand_atoms 3 + + pocket_atomic_numbers, pocket_pos = pocket_pdb_info + pocket_atomic_numbers = torch.from_numpy( + pocket_atomic_numbers + ).long() # n_pocket_atoms + pocket_pos = torch.from_numpy(pocket_pos).float() # n_pocket_atoms 3 + + atomic_numbers, pos = pack([ligand_atomic_numbers, pocket_atomic_numbers], "*") + pos, _ = pack([ligand_pos, pocket_pos], "* p") + tags, _ = pack( + [ + torch.zeros_like(ligand_atomic_numbers), + torch.ones_like(pocket_atomic_numbers), + ], + "*", + ) + + y = self.y[idx] + + data = Data.from_dict( + { + "idx": idx, + "atomic_numbers": atomic_numbers, + "pos": pos, + "tags": tags, + "y": y, + # Save the y_mean and y_std in the data object for metrics + "y_mean": self.y_mean, + "y_std": self.y_std, + } + ) + data = cast(BaseData, data) + return data + + # Function to extract atomic numbers and coordinates from an SDF file + @staticmethod + @lru_cache(maxsize=16) + def _get_sdf_info(sdf_file: str): + suppl = Chem.SDMolSupplier(sdf_file, sanitize=False) + for mol in suppl: + if mol is None: + continue + atoms = mol.GetAtoms() + atomic_numbers = np.array([atom.GetAtomicNum() for atom in atoms]) + coords = mol.GetConformer().GetPositions() + return atomic_numbers, coords + return None + + # Function to extract atomic numbers and coordinates from a PDB file + @staticmethod + @lru_cache(maxsize=16) + def _get_pdb_info(pdb_file: str): + global _pt + + parser = PDB.PDBParser() + structure = parser.get_structure("structure", pdb_file) + for model in structure: + atomic_numbers = [] + coords = [] + for chain in model: + for residue in chain: + for atom in residue: + symbol = atom.element.strip() + symbol = f"{symbol[0].upper()}{symbol[1:].lower()}" + atomic_numbers.append(_pt.GetAtomicNumber(symbol)) + coords.append(atom.get_coord()) + return np.array(atomic_numbers), np.array(coords) + return None + + +class PDBBindConfig(TypedConfig): + task: PDBBindTask = "-logKd/Ki" + split: Literal["train", "val", "test"] diff --git a/src/jmp/datasets/pretrain_lmdb.py b/src/jmp/datasets/pretrain_lmdb.py new file mode 100644 index 0000000..30c176d --- /dev/null +++ b/src/jmp/datasets/pretrain_lmdb.py @@ -0,0 +1,266 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import bisect +import pickle +from collections.abc import Callable, Mapping +from functools import cache +from pathlib import Path +from typing import Any + +import lmdb +import numpy as np +import torch +from jmp.lightning import TypedConfig +from torch.utils.data import Dataset +from torch_geometric.data.data import BaseData +from typing_extensions import override + +from ..utils.ocp import pyg2_data_transform + + +class PretrainDatasetConfig(TypedConfig): + src: Path + """Path to the LMDB file or directory containing LMDB files.""" + + metadata_path: Path | None = None + """Path to the metadata npz file containing the number of atoms in each structure.""" + + total_energy: bool | None = None + """Whether to train on total energies.""" + oc20_ref: Path | None = None + """Path to the OC20 reference energies file.""" + lin_ref: Path | None = None + """Path to the linear reference energies file.""" + + def __post_init__(self): + super().__post_init__() + + # If metadata_path is not provided, assume it is src/metadata.npz + if self.metadata_path is None: + self.metadata_path = self.src / "metadata.npz" + + +class PretrainLmdbDataset(Dataset[BaseData]): + r"""Dataset class to load from LMDB files containing relaxation + trajectories or single point computations. + + Useful for Structure to Energy & Force (S2EF), Initial State to + Relaxed State (IS2RS), and Initial State to Relaxed Energy (IS2RE) tasks. + + Args: + config (dict): Dataset configuration + transform (callable, optional): Data transform function. + (default: :obj:`None`) + """ + + def data_sizes(self, indices: list[int]) -> np.ndarray: + return self.atoms_metadata[indices] + + @property + def atoms_metadata(self) -> np.ndarray: + if ( + metadata := next( + ( + self.metadata[k] + for k in ["natoms", "num_nodes"] + if k in self.metadata + ), + None, + ) + ) is None: + raise ValueError( + f"Could not find atoms metadata key in loaded metadata.\n" + f"Available keys: {list(self.metadata.keys())}" + ) + return metadata + + @property + @cache + def metadata(self) -> Mapping[str, np.ndarray]: + metadata_path = getattr(self, "metadata_path", None) + if not metadata_path or not metadata_path.is_file(): + metadata_path = self.config.metadata_path + + if metadata_path and metadata_path.is_file(): + return np.load(metadata_path, allow_pickle=True) + + raise ValueError(f"Could not find atoms metadata in {metadata_path=}.") + + def __init__( + self, + config: PretrainDatasetConfig, + use_referenced_energies: bool = True, + transform: Callable[[BaseData], Any] | None = None, + ): + super(PretrainLmdbDataset, self).__init__() + self.config = config + + self.path = Path(self.config.src) + if not self.path.is_file(): + db_paths = sorted(self.path.glob("*.lmdb")) + assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" + + # self.metadata_path = self.path / "metadata.npz" + + self._keys, self.envs = [], [] + for db_path in db_paths: + self.envs.append(self.connect_db(db_path)) + try: + length = pickle.loads( + self.envs[-1].begin().get("length".encode("ascii")) + ) + except TypeError: + length = self.envs[-1].stat()["entries"] + self._keys.append(list(range(length))) + + keylens = [len(k) for k in self._keys] + self._keylen_cumulative = np.cumsum(keylens).tolist() + self.num_samples = sum(keylens) + + else: + # self.metadata_path = self.path.parent / "metadata.npz" + self.env = self.connect_db(self.path) + self._keys = [ + f"{j}".encode("ascii") for j in range(self.env.stat()["entries"]) + ] + self.num_samples = len(self._keys) + + self.transform = transform + self.lin_ref = self.oc20_ref = None + self.train_total = self.config.total_energy or False + # only needed for oc20 datasets, p is total by default + if self.train_total: + oc20_ref = self.config.oc20_ref + if not oc20_ref: + raise ValueError("oc20_ref must be provided for oc20 datasets") + self.oc20_ref = pickle.load(open(oc20_ref, "rb")) + + if (lin_ref := self.config.lin_ref) is not None and use_referenced_energies: + coeff = np.load(lin_ref, allow_pickle=True)["coeff"] + try: + self.lin_ref = torch.nn.Parameter( + torch.tensor(coeff), requires_grad=False + ) + except BaseException: + self.lin_ref = torch.nn.Parameter( + torch.tensor(coeff[0]), requires_grad=False + ) + + def __len__(self): + return self.num_samples + + @override + def __getitem__(self, idx): + if not self.path.is_file(): + # Figure out which db this should be indexed from. + db_idx = bisect.bisect(self._keylen_cumulative, idx) + # Extract index of element within that db. + el_idx = idx + if db_idx != 0: + el_idx = idx - self._keylen_cumulative[db_idx - 1] + assert el_idx >= 0 + + # Return features. + with self.envs[db_idx].begin(write=False) as txn: + datapoint_pickled = txn.get( + f"{self._keys[db_idx][el_idx]}".encode("ascii") + ) + data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) + data_object.id = f"{db_idx}_{el_idx}" + else: + with self.env.begin(write=False) as txn: + datapoint_pickled = txn.get(self._keys[idx]) + data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) + + if self.transform is not None: + data_object = self.transform(data_object) + # make types consistent + sid = data_object.sid + if isinstance(sid, torch.Tensor): + sid = sid.item() + data_object.sid = sid + if "fid" in data_object: + fid = data_object.fid + if isinstance(fid, torch.Tensor): + fid = fid.item() + data_object.fid = fid + if "bulk" in data_object: + del data_object.bulk + if hasattr(data_object, "y_relaxed"): + attr = "y_relaxed" + elif hasattr(data_object, "y"): + attr = "y" + # if targets are not available, test data is being used + else: + return data_object + + # convert s2ef energies to raw energies + if attr == "y": + # OC20 data + if "p" not in data_object and self.train_total: + assert self.oc20_ref is not None + + randomid = f"random{sid}" + if hasattr(data_object, "task_mask"): + data_object[attr][data_object.task_mask] += self.oc20_ref[randomid] + else: + data_object[attr] += self.oc20_ref[randomid] + data_object.nads = 1 + data_object.p = 0 + + # convert is2re energies to raw energies + else: + if "p" not in data_object and self.train_total: + assert self.oc20_ref is not None + + randomid = f"random{sid}" + data_object[attr] += self.oc20_ref[randomid] + del data_object.force + del data_object.y_init + data_object.nads = 1 + data_object.p = 0 + + if self.lin_ref is not None: + lin_energy = sum(self.lin_ref[data_object.atomic_numbers.long()]) + if hasattr(data_object, "task_mask"): + data_object[attr][data_object.task_mask] -= lin_energy + else: + data_object[attr] -= lin_energy + if "nads" in data_object: + del data_object.nads + if "p" in data_object: + del data_object.p + # to jointly train on p+oc20, need to delete these oc20-only attributes + # ensure otf_graph=1 in your model configuration + # if "edge_index" in data_object: + # del data_object.edge_index + # if "cell_offsets" in data_object: + # del data_object.cell_offsets + # if "distances" in data_object: + # del data_object.distances + + return data_object + + def connect_db(self, lmdb_path=None): + env = lmdb.open( + str(lmdb_path), + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + return env + + def close_db(self): + if not self.path.is_file(): + for env in self.envs: + env.close() + else: + self.env.close() diff --git a/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_dataloader_ase_traj.py b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_dataloader_ase_traj.py new file mode 100644 index 0000000..bc4a300 --- /dev/null +++ b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_dataloader_ase_traj.py @@ -0,0 +1,69 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +# pylint: disable=stop-iteration-return + +import h5py +import numpy as np +from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator as SPCalc +from ase.units import Hartree + + +def generator(formula, grp): + """Iterates through a h5 group""" + + energies = grp["wb97x_dz.energy"] + forces = grp["wb97x_dz.forces"] + atomic_numbers = list(grp["atomic_numbers"]) + positions = grp["coordinates"] + fid = 0 + + for energy, force, positions in zip(energies, forces, positions): + # skip if energy/force is nan + if np.isnan(energy) or np.isnan(force).any(): + continue + # get ase atoms object + atoms = Atoms(atomic_numbers, positions=positions) + # convert from hartree to eV and hartree/angstrom to eV/angstrom + energy = energy * Hartree + force = force * Hartree + sp_calc = SPCalc(atoms=atoms, energy=energy, forces=force.tolist()) + sp_calc.implemented_properties = ["energy", "forces"] + atoms.set_calculator(sp_calc) + atoms.set_tags(2 * np.ones(len(atomic_numbers))) + id = (formula, fid) + fid += 1 + + yield id, atoms + + +class Dataloader: + """ + Can iterate through h5 data set for paper #### + + hdf5_file: path to data + only_final: if True, the iterator will only loop through reactant, product and transition + state instead of all configurations for each reaction and return them in dictionaries. + """ + + def __init__(self, hdf5_file, split_keys): + self.hdf5_file = hdf5_file + self.split_keys = split_keys + + def __iter__(self): + with h5py.File(self.hdf5_file, "r") as h5_file: + for key in self.split_keys: + atoms_list = [] + id_list = [] + for id, molecule in generator(key, h5_file[key]): + atoms_list.append(molecule) + id_list.append(id) + assert len(atoms_list) == h5_file[key]["coordinates"].shape[0] + + yield id_list, atoms_list diff --git a/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_linear_ref.py b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_linear_ref.py new file mode 100644 index 0000000..221fd44 --- /dev/null +++ b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_linear_ref.py @@ -0,0 +1,120 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import pickle +from functools import cache +from pathlib import Path + +import multiprocess as mp +import numpy as np +import torch +from jmp.datasets.pretrain_lmdb import PretrainDatasetConfig, PretrainLmdbDataset +from torch_scatter import scatter +from tqdm import tqdm + + +def _compute_mean_std(args: argparse.Namespace): + @cache + def dataset(): + return PretrainLmdbDataset( + PretrainDatasetConfig(src=args.src, lin_ref=args.linref_path) + ) + + def extract_data(idx): + data = dataset()[idx] + y = data.y + na = data.natoms + return (y, na) + + pool = mp.Pool(args.num_workers) + indices = range(len(dataset())) + + outputs = list(tqdm(pool.imap(extract_data, indices), total=len(indices))) + + energies = [y for y, na in outputs] + num_atoms = [na for y, na in outputs] + + energy_mean = np.mean(energies) + energy_std = np.std(energies) + avg_num_atoms = np.mean(num_atoms) + + print( + f"energy_mean: {energy_mean}, energy_std: {energy_std}, average number of atoms: {avg_num_atoms}" + ) + + with open(args.out_path, "wb") as f: + pickle.dump( + { + "energy_mean": energy_mean, + "energy_std": energy_std, + "avg_num_atoms": avg_num_atoms, + }, + f, + ) + + +def _linref(args: argparse.Namespace): + @cache + def dataset(): + return PretrainLmdbDataset(PretrainDatasetConfig(src=args.src)) + + def extract_data(idx): + data = dataset()[idx] + x = ( + scatter( + torch.ones(data.atomic_numbers.shape[0]), + data.atomic_numbers.long(), + dim_size=10, + ) + .long() + .numpy() + ) + y = data.y + return (x, y) + + pool = mp.Pool(args.num_workers) + indices = range(len(dataset())) + + outputs = list(tqdm(pool.imap(extract_data, indices), total=len(indices))) + + features = [x[0] for x in outputs] + targets = [x[1] for x in outputs] + + X = np.vstack(features) + y = targets + + coeff = np.linalg.lstsq(X, y, rcond=None)[0] + np.savez_compressed(args.out_path, coeff=coeff) + print(f"Saved linear reference coefficients to {args.out_path}") + + +def main(): + parser = argparse.ArgumentParser() + + subparsers = parser.add_subparsers(dest="subcommand") + + compute_mean_std_parser = subparsers.add_parser("compute_mean_std") + compute_mean_std_parser.add_argument("--src", type=Path, required=True) + compute_mean_std_parser.add_argument("--out_path", type=Path, required=True) + compute_mean_std_parser.add_argument("--linref_path", type=Path, required=True) + compute_mean_std_parser.add_argument("--num_workers", type=int, default=32) + compute_mean_std_parser.set_defaults(fn=_compute_mean_std) + + linref_parser = subparsers.add_parser("linref") + linref_parser.add_argument("--src", type=Path, required=True) + linref_parser.add_argument("--out_path", type=Path, required=True) + linref_parser.add_argument("--num_workers", type=int, default=32) + linref_parser.set_defaults(fn=_linref) + + args = parser.parse_args() + args.fn(args) + + +if __name__ == "__main__": + main() diff --git a/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_splits.py b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_splits.py new file mode 100644 index 0000000..ebe5d79 --- /dev/null +++ b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_splits.py @@ -0,0 +1,100 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import copy +import pickle +import random + +import h5py + + +def remove_small_systems(h5_file, key_list): + sm_sys_list = [] + sm_sys_frames = [] + for key in key_list: + if h5_file[key]["atomic_numbers"].shape[0] < 4: + sm_sys_list.append(key) + sm_sys_frames.append(h5_file[key]["coordinates"].shape[0]) + + print(f"total frames removed: {sum(sm_sys_frames)}") + updated_keys = [i for i in key_list if i not in sm_sys_list] + return updated_keys + + +def get_split(h5_file, rand_key_list, tot_split_size): + split_keys = [] + counter = 0 + for key in rand_key_list: + split_keys.append(key) + counter += h5_file[key]["coordinates"].shape[0] + if counter >= tot_split_size: + break + print(f"total frames in split: {counter}") + return split_keys + + +def main(args): + ani_h5 = h5py.File(args.input_file, "r") + ani_keys = list(ani_h5.keys()) # get all keys which happen to be unique molecules + updated_ani_keys = remove_small_systems(ani_h5, ani_keys) + rand_ani_keys = copy.deepcopy(updated_ani_keys) + + for i in range(6): + random.shuffle(rand_ani_keys) + + val_keys = get_split(ani_h5, rand_ani_keys, args.split_size) + tmp_keys = [i for i in rand_ani_keys if i not in val_keys] + test_keys = get_split(ani_h5, tmp_keys, args.split_size) + train_keys = [i for i in tmp_keys if i not in test_keys] + + with open(args.train_keys_output, "wb") as f: + pickle.dump(train_keys, f) + with open(args.test_keys_output, "wb") as f: + pickle.dump(val_keys, f) + with open(args.val_keys_output, "wb") as f: + pickle.dump(test_keys, f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Split ANI-1x dataset into train, test, and validation sets." + ) + parser.add_argument( + "--input_file", + type=str, + required=True, + help="Path to the input ANI-1x HDF5 file.", + ) + parser.add_argument( + "--split_size", + type=int, + default=495600, + help="Size of each split (default: 495600).", + ) + parser.add_argument( + "--train_keys_output", + type=str, + required=True, + help="Path to save the train keys pickle file.", + ) + parser.add_argument( + "--test_keys_output", + type=str, + required=True, + help="Path to save the test keys pickle file.", + ) + parser.add_argument( + "--val_keys_output", + type=str, + required=True, + help="Path to save the validation keys pickle file.", + ) + + args = parser.parse_args() + main(args) diff --git a/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_write_lmdbs.py b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_write_lmdbs.py new file mode 100644 index 0000000..47584af --- /dev/null +++ b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_write_lmdbs.py @@ -0,0 +1,186 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import multiprocessing as mp +import os +import pickle +from pathlib import Path + +import ase.io +import lmdb +import numpy as np +import torch +from torch_geometric.data import Data +from tqdm import tqdm + +# from ocpmodels.preprocessing import AtomsToGraphs + + +def write_images_to_lmdb(mp_arg): + db_path, samples, sampled_ids, idx, pid, uid_mapping, args = mp_arg + db = lmdb.open( + db_path, + map_size=1099511627776 * 2, + subdir=False, + meminit=False, + map_async=True, + ) + + pbar = tqdm( + total=len(samples), + position=pid, + desc="Preprocessing data into LMDBs", + ) + for sysid, fid in samples: + # Convert from forumula name to a unique id using uid_mapping + uid = uid_mapping[sysid] + + fid = int(fid) + sid = int(uid) + traj_file = os.path.join(args.data_path, f"{sysid}.traj") + # traj_logs = open(sample, "r").read().splitlines() + # xyz_idx = os.path.splitext(os.path.basename(sample))[0] + # traj_path = os.path.join(args.data_path, f"{xyz_idx}.extxyz") + atoms = ase.io.read(traj_file, index=fid) + data_object = Data( + pos=torch.Tensor(atoms.get_positions()), + atomic_numbers=torch.Tensor(atoms.get_atomic_numbers()), + sid=sid, + fid=fid, + natoms=atoms.get_positions().shape[0], + tags=torch.LongTensor(atoms.get_tags()), + force=torch.Tensor(atoms.get_forces()), + pbc=torch.Tensor(atoms.pbc), + y=atoms.get_potential_energy(), + ) + + txn = db.begin(write=True) + txn.put( + f"{idx}".encode("ascii"), + pickle.dumps(data_object, protocol=-1), + ) + txn.commit() + idx += 1 + sampled_ids.append(f"{sysid},{fid},{uid}" + "\n") + pbar.update(1) + + # Save count of objects in lmdb. + txn = db.begin(write=True) + txn.put("length".encode("ascii"), pickle.dumps(idx, protocol=-1)) + txn.commit() + + db.sync() + db.close() + + return sampled_ids, idx + + +def main(args): + # xyz_logs = glob.glob(os.path.join(args.data_path, "*.txt")) + # if not xyz_logs: + # raise RuntimeError("No *.txt files found. Did you uncompress?") + + # if args.num_workers > len(xyz_logs): + # args.num_workers = len(xyz_logs) + ids_file = os.path.join(args.data_path, f"{args.split}_ids.pkl") + with open(ids_file, "rb") as f: + ids = pickle.load(f) + + # Load the uid mapping: + # Find all *.traj files in the data path + data_path = Path(args.data_path) + uid_mapping = {} + for traj_file in data_path.glob("*.traj"): + uid_mapping[traj_file.stem] = len(uid_mapping) + + # Initialize feature extractor. + """a2g = AtomsToGraphs( + max_neigh=50, + radius=6, + r_energy=not args.test_data, + r_forces=not args.test_data, + r_fixed=True, + r_distances=False, + r_edges=args.get_edges, + )""" + + # Create output directory if it doesn't exist. + os.makedirs(os.path.join(args.out_path), exist_ok=True) + + # Initialize lmdb paths + db_paths = [ + os.path.join(args.out_path, "data.%04d.lmdb" % i) + for i in range(args.num_workers) + ] + + # Chunk the trajectories into args.num_workers splits + chunked_ids = np.array_split(ids, args.num_workers) + + # Extract features + sampled_ids, idx = [[]] * args.num_workers, [0] * args.num_workers + + pool = mp.Pool(args.num_workers) + mp_args = [ + ( + db_paths[i], + chunked_ids[i], + sampled_ids[i], + idx[i], + i, + uid_mapping, + args, + ) + for i in range(args.num_workers) + ] + op = list(zip(*pool.imap(write_images_to_lmdb, mp_args))) + sampled_ids, idx = list(op[0]), list(op[1]) + + # Log sampled image, trajectory trace + for j, i in enumerate(range(args.num_workers)): + ids_log = open(os.path.join(args.out_path, "data_log.%04d.txt" % i), "w") + ids_log.writelines(sampled_ids[j]) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", help="Path to dir containing *.traj files") + parser.add_argument( + "--out-path", + help="Directory to save extracted features. Will create if doesn't exist", + ) + parser.add_argument( + "--split", + help="train, test, or val", + ) + """parser.add_argument( + "--get-edges", + action="store_true", + help="Store edge indices in LMDB, ~10x storage requirement. Default: compute edge indices on-the-fly.", + )""" + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="No. of feature-extracting processes or no. of dataset chunks", + ) + """parser.add_argument( + "--ref-energy", action="store_true", help="Subtract reference energies" + )""" + # parser.add_argument( + # "--test-data", + # action="store_true", + # help="Is data being processed test data?", + # ) + return parser + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_write_traj.py b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_write_traj.py new file mode 100644 index 0000000..e40494f --- /dev/null +++ b/src/jmp/datasets/scripts/ani1x_preprocess/ani1x_write_traj.py @@ -0,0 +1,91 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import multiprocessing as mp +import os +import pickle +import random +import time + +import ase.io + +from .ani1x_dataloader_ase_traj import Dataloader + + +def write_traj(data_i): + id_list, atoms_list = data_i + traj_file = os.path.join(args.traj_dir, id_list[0][0] + ".traj") + # write ase traj file + ase.io.write(traj_file, atoms_list, format="traj") + return id_list + + +def main(args): + start = time.time() + + pool = mp.Pool(args.num_workers) + split_keys = pickle.load(open(args.split_keys, "rb")) + dataloader = Dataloader(args.ani1x_h5, split_keys) + + out_pool = list(pool.imap_unordered(write_traj, dataloader)) + # flatten list of lists + sampled_ids = [item for sublist in out_pool for item in sublist] + + random.shuffle(sampled_ids) + random.shuffle(sampled_ids) + random.shuffle(sampled_ids) + # write pkl files for ids + id_file = os.path.join(args.traj_dir, args.split + "_ids.pkl") + with open(id_file, "wb") as f: + pickle.dump(sampled_ids, f) + + end = time.time() + total_time = end - start + print("Total time: ", total_time) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate ASE trajectory files from ANI-1x dataset." + ) + parser.add_argument( + "--ani1x_h5", type=str, required=True, help="Path to the ANI-1x HDF5 file." + ) + parser.add_argument( + "--split_keys", + type=str, + required=True, + help="Path to the pickle file containing split keys.", + ) + parser.add_argument( + "--split", + type=str, + required=True, + help="Name of the split (e.g., train, test, val).", + ) + parser.add_argument( + "--traj_dir", + type=str, + required=True, + help="Directory to save the ASE trajectory files.", + ) + parser.add_argument( + "--num_workers", + type=int, + default=64, + help="Number of worker processes (default: 64).", + ) + + args = parser.parse_args() + + # Load split keys from pickle file + with open(args.split_keys, "rb") as f: + split_keys = pickle.load(f) + + main(args) diff --git a/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_dataloader_ase_traj.py b/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_dataloader_ase_traj.py new file mode 100644 index 0000000..f901b4e --- /dev/null +++ b/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_dataloader_ase_traj.py @@ -0,0 +1,120 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +# pylint: disable=stop-iteration-return + +import h5py +import numpy as np +from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator as SPCalc + +REFERENCE_ENERGIES = { + 1: -13.62222753701504, + 6: -1029.4130839658328, + 7: -1484.8710358098756, + 8: -2041.8396277138045, + 9: -2712.8213146878606, +} + + +def get_molecular_reference_energy(atomic_numbers): + molecular_reference_energy = 0 + for atomic_number in atomic_numbers: + molecular_reference_energy += REFERENCE_ENERGIES[atomic_number] + + return molecular_reference_energy + + +def generator(formula, rxn, grp): + """Iterates through a h5 group""" + + energies = grp["wB97x_6-31G(d).energy"] + forces = grp["wB97x_6-31G(d).forces"] + atomic_numbers = list(grp["atomic_numbers"]) + positions = grp["positions"] + molecular_reference_energy = get_molecular_reference_energy(atomic_numbers) + fid = 0 + + for energy, force, positions in zip(energies, forces, positions): + # get ase atoms object + atoms = Atoms(atomic_numbers, positions=positions) + sp_calc = SPCalc(atoms=atoms, energy=energy, forces=force.tolist()) + sp_calc.implemented_properties = ["energy", "forces"] + atoms.set_calculator(sp_calc) + atoms.set_tags(2 * np.ones(len(atomic_numbers))) + id = (f"{formula}_{rxn}", fid) + fid += 1 + + """d = { + "rxn": rxn, + "wB97x_6-31G(d).energy": energy.__float__(), + "wB97x_6-31G(d).atomization_energy": energy + - molecular_reference_energy.__float__(), + "wB97x_6-31G(d).forces": force.tolist(), + "positions": positions, + "formula": formula, + "atomic_numbers": atomic_numbers, + }""" + + yield id, atoms + + +class Dataloader: + """ + Can iterate through h5 data set for paper #### + + hdf5_file: path to data + only_final: if True, the iterator will only loop through reactant, product and transition + state instead of all configurations for each reaction and return them in dictionaries. + """ + + def __init__(self, hdf5_file, datasplit="data", only_final=False): + self.hdf5_file = hdf5_file + self.only_final = only_final + + self.datasplit = datasplit + if datasplit: + assert datasplit in [ + "data", + "train", + "val", + "test", + ], "datasplit must be one of 'all', 'train', 'val' or 'test'" + + def __iter__(self): + with h5py.File(self.hdf5_file, "r") as f: + split = f[self.datasplit] + + for formula, grp in split.items(): + for rxn, subgrp in grp.items(): + # reactant = next(generator(formula, rxn, subgrp["reactant"])) + # product = next(generator(formula, rxn, subgrp["product"])) + + """if self.only_final: + transition_state = next( + generator(formula, rxn, subgrp["transition_state"]) + ) + yield { + "rxn": rxn, + "reactant": reactant, + "product": product, + "transition_state": transition_state, + }""" + # yield (reactant, "reactant") + # yield (product, "product") + rxn_atoms_list = [] + id_list = [] + sm_sys = None + for id, molecule in generator(formula, rxn, subgrp): + rxn_atoms_list.append(molecule) + id_list.append(id) + assert len(rxn_atoms_list) == subgrp["positions"].shape[0] + # marking systems that have less than 4 atoms + if subgrp["atomic_numbers"].shape[0] < 4: + sm_sys = (f"{formula}_{rxn}", subgrp["atomic_numbers"].shape[0]) + yield id_list, rxn_atoms_list, sm_sys diff --git a/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_linear_ref.py b/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_linear_ref.py new file mode 100644 index 0000000..ea0d45f --- /dev/null +++ b/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_linear_ref.py @@ -0,0 +1,123 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import pickle +from functools import cache +from pathlib import Path + +import multiprocess as mp +import numpy as np +import torch +from jmp.datasets.pretrain_lmdb import PretrainDatasetConfig, PretrainLmdbDataset +from torch_scatter import scatter +from tqdm import tqdm + + +def _compute_mean_std(args: argparse.Namespace): + @cache + def dataset(): + return PretrainLmdbDataset( + PretrainDatasetConfig(src=args.src, lin_ref=args.linref_path) + ) + + def extract_data(idx): + data = dataset()[idx] + y = data.y + f = data.force.cpu().numpy() + return (y, f) + + pool = mp.Pool(args.num_workers) + indices = range(len(dataset())) + + outputs = list(tqdm(pool.imap(extract_data, indices), total=len(indices))) + + energies = [y for y, forces in outputs] + forces = np.array([force for y, forces in outputs for force in forces]) + + energy_mean = np.mean(energies) + energy_std = np.std(energies) + force_rms = np.sqrt(np.mean(np.square(forces))) + force_md = np.mean(np.linalg.norm(forces, axis=-1)) + + print( + f"energy_mean: {energy_mean}, energy_std: {energy_std}, force_rms: {force_rms}, force_md: {force_md}" + ) + + with open(args.out_path, "wb") as f: + pickle.dump( + { + "energy_mean": energy_mean, + "energy_std": energy_std, + "force_rms": force_rms, + "force_md": force_md, + }, + f, + ) + + +def _linref(args: argparse.Namespace): + @cache + def dataset(): + return PretrainLmdbDataset(PretrainDatasetConfig(src=args.src)) + + def extract_data(idx): + data = dataset()[idx] + x = ( + scatter( + torch.ones(data.atomic_numbers.shape[0]), + data.atomic_numbers.long(), + dim_size=10, + ) + .long() + .numpy() + ) + y = data.y + f = data.force.cpu().numpy() + return (x, y, f) + + pool = mp.Pool(args.num_workers) + indices = range(len(dataset())) + + outputs = list(tqdm(pool.imap(extract_data, indices), total=len(indices))) + + features = [x[0] for x in outputs] + targets = [x[1] for x in outputs] + + X = np.vstack(features) + y = targets + + coeff = np.linalg.lstsq(X, y, rcond=None)[0] + np.savez_compressed(args.out_path, coeff=coeff) + print(f"Saved linear reference coefficients to {args.out_path}") + + +def main(): + parser = argparse.ArgumentParser() + + subparsers = parser.add_subparsers(dest="subcommand") + + compute_mean_std_parser = subparsers.add_parser("compute_mean_std") + compute_mean_std_parser.add_argument("--src", type=Path) + compute_mean_std_parser.add_argument("--out_path", type=Path) + compute_mean_std_parser.add_argument("--linref_path", type=Path) + compute_mean_std_parser.add_argument("--num_workers", type=int, default=32) + compute_mean_std_parser.set_defaults(fn=_compute_mean_std) + + linref_parser = subparsers.add_parser("linref") + linref_parser.add_argument("--src", type=Path) + linref_parser.add_argument("--out_path", type=Path) + linref_parser.add_argument("--num_workers", type=int, default=32) + linref_parser.set_defaults(fn=_linref) + + args = parser.parse_args() + args.fn(args) + + +if __name__ == "__main__": + main() diff --git a/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_write_lmdbs.py b/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_write_lmdbs.py new file mode 100644 index 0000000..53a1a31 --- /dev/null +++ b/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_write_lmdbs.py @@ -0,0 +1,174 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import multiprocessing as mp +import os +import pickle + +import ase.io +import lmdb +import numpy as np +import torch +from torch_geometric.data import Data +from tqdm import tqdm + +# from ocpmodels.preprocessing import AtomsToGraphs + + +def write_images_to_lmdb(mp_arg): + db_path, samples, sampled_ids, idx, pid, args = mp_arg + db = lmdb.open( + db_path, + map_size=1099511627776 * 2, + subdir=False, + meminit=False, + map_async=True, + ) + + pbar = tqdm( + total=len(samples), + position=pid, + desc="Preprocessing data into LMDBs", + ) + for sysid, fid in samples: + fid = int(fid) + sid = int(sysid.split("rxn")[1]) + traj_file = os.path.join(args.data_path, f"{sysid}.traj") + # traj_logs = open(sample, "r").read().splitlines() + # xyz_idx = os.path.splitext(os.path.basename(sample))[0] + # traj_path = os.path.join(args.data_path, f"{xyz_idx}.extxyz") + atoms = ase.io.read(traj_file, index=fid) + data_object = Data( + pos=torch.Tensor(atoms.get_positions()), + atomic_numbers=torch.Tensor(atoms.get_atomic_numbers()), + sid=sid, + fid=fid, + natoms=atoms.get_positions().shape[0], + tags=torch.LongTensor(atoms.get_tags()), + force=torch.Tensor(atoms.get_forces()), + pbc=torch.Tensor(atoms.pbc), + y=atoms.get_potential_energy(), + ) + + txn = db.begin(write=True) + txn.put( + f"{idx}".encode("ascii"), + pickle.dumps(data_object, protocol=-1), + ) + txn.commit() + idx += 1 + sampled_ids.append(f"{sysid},{fid}" + "\n") + pbar.update(1) + + # Save count of objects in lmdb. + txn = db.begin(write=True) + txn.put("length".encode("ascii"), pickle.dumps(idx, protocol=-1)) + txn.commit() + + db.sync() + db.close() + + return sampled_ids, idx + + +def main(args): + # xyz_logs = glob.glob(os.path.join(args.data_path, "*.txt")) + # if not xyz_logs: + # raise RuntimeError("No *.txt files found. Did you uncompress?") + + # if args.num_workers > len(xyz_logs): + # args.num_workers = len(xyz_logs) + ids_file = os.path.join(args.data_path, f"{args.split}_ids.pkl") + with open(ids_file, "rb") as f: + ids = pickle.load(f) + + # Initialize feature extractor. + """a2g = AtomsToGraphs( + max_neigh=50, + radius=6, + r_energy=not args.test_data, + r_forces=not args.test_data, + r_fixed=True, + r_distances=False, + r_edges=args.get_edges, + )""" + + # Create output directory if it doesn't exist. + os.makedirs(os.path.join(args.out_path), exist_ok=True) + + # Initialize lmdb paths + db_paths = [ + os.path.join(args.out_path, "data.%04d.lmdb" % i) + for i in range(args.num_workers) + ] + + # Chunk the trajectories into args.num_workers splits + chunked_ids = np.array_split(ids, args.num_workers) + + # Extract features + sampled_ids, idx = [[]] * args.num_workers, [0] * args.num_workers + + pool = mp.Pool(args.num_workers) + mp_args = [ + ( + db_paths[i], + chunked_ids[i], + sampled_ids[i], + idx[i], + i, + args, + ) + for i in range(args.num_workers) + ] + op = list(zip(*pool.imap(write_images_to_lmdb, mp_args))) + sampled_ids, idx = list(op[0]), list(op[1]) + + # Log sampled image, trajectory trace + for j, i in enumerate(range(args.num_workers)): + ids_log = open(os.path.join(args.out_path, "data_log.%04d.txt" % i), "w") + ids_log.writelines(sampled_ids[j]) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", help="Path to dir containing *.traj files") + parser.add_argument( + "--out_path", + help="Directory to save extracted features. Will create if doesn't exist", + ) + parser.add_argument( + "--split", + help="train, test, or val", + ) + """parser.add_argument( + "--get-edges", + action="store_true", + help="Store edge indices in LMDB, ~10x storage requirement. Default: compute edge indices on-the-fly.", + )""" + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="No. of feature-extracting processes or no. of dataset chunks", + ) + """parser.add_argument( + "--ref-energy", action="store_true", help="Subtract reference energies" + )""" + # parser.add_argument( + # "--test-data", + # action="store_true", + # help="Is data being processed test data?", + # ) + return parser + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_write_traj.py b/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_write_traj.py new file mode 100644 index 0000000..50690e9 --- /dev/null +++ b/src/jmp/datasets/scripts/transition1x_preprocess/trans1x_write_traj.py @@ -0,0 +1,82 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import multiprocessing as mp +import os +import pickle +import random +import time +from functools import partial + +import ase.io + +from .trans1x_dataloader_ase_traj import Dataloader + + +def write_traj(data_i, *, args: argparse.Namespace): + id_list, atoms_list, sm_sys = data_i + traj_file = os.path.join(args.traj_dir, id_list[0][0] + ".traj") + # write ase traj file + ase.io.write(traj_file, atoms_list, format="traj") + return id_list, sm_sys + + +def main(args: argparse.Namespace): + start = time.time() + + pool = mp.Pool(args.num_workers) + dataloader = Dataloader(args.transition1x_h5, datasplit=args.split) + + out_pool = list(zip(*pool.imap(partial(write_traj, args=args), dataloader))) + sampled_ids, sm_sys = list(out_pool[0]), list(out_pool[1]) + # flatten list of lists + sampled_ids = [item for sublist in sampled_ids for item in sublist] + + random.shuffle(sampled_ids) + random.shuffle(sampled_ids) + random.shuffle(sampled_ids) + # write plk files for ids and small systems + id_file = os.path.join(args.traj_dir, args.split + "_ids.pkl") + with open(id_file, "wb") as f: + pickle.dump(sampled_ids, f) + sm_sys_file = os.path.join(args.traj_dir, args.split + "_sm_sys.pkl") + with open(sm_sys_file, "wb") as f: + pickle.dump(sm_sys, f) + + end = time.time() + total_time = end - start + print("Total time: ", total_time) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--transition1x_h5", + type=str, + help="Path to the HDF5 file containing the dataset.", + ) + parser.add_argument( + "--traj_dir", + type=str, + help="Directory to save the trajectory files.", + ) + parser.add_argument( + "--split", + type=str, + default="train", + choices=["train", "val", "test"], + ) + parser.add_argument( + "--num_workers", + type=int, + default=os.cpu_count(), + help="Number of worker processes.", + ) + args = parser.parse_args() + main(args) diff --git a/src/jmp/lightning/__init__.py b/src/jmp/lightning/__init__.py new file mode 100644 index 0000000..698d2d6 --- /dev/null +++ b/src/jmp/lightning/__init__.py @@ -0,0 +1,72 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from . import actsave as A +from .actsave import ActSave +from .config import MISSING, AllowMissing, Field, MissingField, PrivateAttr, TypedConfig +from .data import dataset_transform +from .exception import SkipBatch +from .model.base import Base, LightningDataModuleBase, LightningModuleBase +from .model.config import ( + BaseConfig, + CSVLoggingConfig, + EnvironmentConfig, + GradientClippingConfig, + GradientSkippingConfig, + LoggingConfig, + OptimizerConfig, + PythonLogging, + RunnerConfig, + RunnerOutputSaveConfig, + TensorboardLoggingConfig, + TrainerConfig, + WandbLoggingConfig, + WandbWatchConfig, +) +from .modules.normalizer import NormalizerConfig +from .runner import Runner +from .trainer import Trainer +from .util.singleton import Registry, Singleton +from .util.typed import TypedModuleDict, TypedModuleList + +__all__ = [ + "A", + "ActSave", + "MISSING", + "AllowMissing", + "Field", + "MissingField", + "PrivateAttr", + "TypedConfig", + "dataset_transform", + "SkipBatch", + "Base", + "LightningDataModuleBase", + "LightningModuleBase", + "BaseConfig", + "CSVLoggingConfig", + "EnvironmentConfig", + "GradientClippingConfig", + "GradientSkippingConfig", + "LoggingConfig", + "OptimizerConfig", + "PythonLogging", + "RunnerConfig", + "RunnerOutputSaveConfig", + "TensorboardLoggingConfig", + "TrainerConfig", + "WandbLoggingConfig", + "WandbWatchConfig", + "NormalizerConfig", + "Runner", + "Trainer", + "Registry", + "Singleton", + "TypedModuleDict", + "TypedModuleList", +] diff --git a/src/jmp/lightning/_config/missing.py b/src/jmp/lightning/_config/missing.py new file mode 100644 index 0000000..26b7ca8 --- /dev/null +++ b/src/jmp/lightning/_config/missing.py @@ -0,0 +1,195 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast + +from pydantic import BaseModel, Field +from pydantic.config import JsonDict +from pydantic.fields import AliasChoices, AliasPath, FieldInfo, _EmptyKwargs, _Unset +from pydantic.types import Discriminator +from pydantic_core import PydanticCustomError, PydanticUndefined +from typing_extensions import TypeVar, Unpack + + +@dataclass +class AllowMissingAnnotation: + pass + + +MISSING = cast(Any, None) + +T = TypeVar("T", infer_variance=True) +if TYPE_CHECKING: + AllowMissing: TypeAlias = Annotated[T, AllowMissingAnnotation()] +else: + AllowMissing: TypeAlias = Annotated[T | None, AllowMissingAnnotation()] + + +def validate_no_missing_values(model: BaseModel): + for name, field in model.model_fields.items(): + # If the field doesn't have the `AllowMissing` annotation, ignore it. + # (i.e., just let Pydantic do its thing). + allow_missing_annotation = next( + (m for m in field.metadata if isinstance(m, AllowMissingAnnotation)), + None, + ) + if allow_missing_annotation is None: + continue + + # By this point, the field **should** have some value. + if not hasattr(model, name): + raise PydanticCustomError( + "field_not_set", + 'Field "{name}" is missing from the model.', + {"name": name}, + ) + + # Now, we error out if the field is missing. + if getattr(model, name) is None: + raise PydanticCustomError( + "field_MISSING", + 'Field "{name}" is still `MISSING`. Please provide a value for it.', + {"name": name}, + ) + + +def MissingField( # noqa: C901 + default: Any = PydanticUndefined, + *, + default_factory: Callable[[], Any] | None = _Unset, + alias: str | None = _Unset, + alias_priority: int | None = _Unset, + validation_alias: str | AliasPath | AliasChoices | None = _Unset, + serialization_alias: str | None = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + examples: list[Any] | None = _Unset, + exclude: bool | None = _Unset, + discriminator: str | Discriminator | None = _Unset, + json_schema_extra: JsonDict | Callable[[JsonDict], None] | None = _Unset, + frozen: bool | None = _Unset, + validate_default: bool | None = _Unset, + repr: bool = _Unset, + init: bool | None = _Unset, + init_var: bool | None = _Unset, + kw_only: bool | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + union_mode: Literal["smart", "left_to_right"] = _Unset, + **extra: Unpack[_EmptyKwargs], +) -> Any: + """Usage docs: https://docs.pydantic.dev/2.7/concepts/fields + + Create a field for objects that can be configured. + + Used to provide extra information about a field, either for the model schema or complex validation. Some arguments + apply only to number fields (`int`, `float`, `Decimal`) and some apply only to `str`. + + Note: + - Any `_Unset` objects will be replaced by the corresponding value defined in the `_DefaultValues` dictionary. If a key for the `_Unset` object is not found in the `_DefaultValues` dictionary, it will default to `None` + + Args: + default: Default value if the field is not set. + default_factory: A callable to generate the default value, such as :func:`~datetime.utcnow`. + alias: The name to use for the attribute when validating or serializing by alias. + This is often used for things like converting between snake and camel case. + alias_priority: Priority of the alias. This affects whether an alias generator is used. + validation_alias: Like `alias`, but only affects validation, not serialization. + serialization_alias: Like `alias`, but only affects serialization, not validation. + title: Human-readable title. + description: Human-readable description. + examples: Example values for this field. + exclude: Whether to exclude the field from the model serialization. + discriminator: Field name or Discriminator for discriminating the type in a tagged union. + json_schema_extra: A dict or callable to provide extra JSON schema properties. + frozen: Whether the field is frozen. If true, attempts to change the value on an instance will raise an error. + validate_default: If `True`, apply validation to the default value every time you create an instance. + Otherwise, for performance reasons, the default value of the field is trusted and not validated. + repr: A boolean indicating whether to include the field in the `__repr__` output. + init: Whether the field should be included in the constructor of the dataclass. + (Only applies to dataclasses.) + init_var: Whether the field should _only_ be included in the constructor of the dataclass. + (Only applies to dataclasses.) + kw_only: Whether the field should be a keyword-only argument in the constructor of the dataclass. + (Only applies to dataclasses.) + strict: If `True`, strict validation is applied to the field. + See [Strict Mode](../concepts/strict_mode.md) for details. + gt: Greater than. If set, value must be greater than this. Only applicable to numbers. + ge: Greater than or equal. If set, value must be greater than or equal to this. Only applicable to numbers. + lt: Less than. If set, value must be less than this. Only applicable to numbers. + le: Less than or equal. If set, value must be less than or equal to this. Only applicable to numbers. + multiple_of: Value must be a multiple of this. Only applicable to numbers. + min_length: Minimum length for iterables. + max_length: Maximum length for iterables. + pattern: Pattern for strings (a regular expression). + allow_inf_nan: Allow `inf`, `-inf`, `nan`. Only applicable to numbers. + max_digits: Maximum number of allow digits for strings. + decimal_places: Maximum number of decimal places allowed for numbers. + union_mode: The strategy to apply when validating a union. Can be `smart` (the default), or `left_to_right`. + See [Union Mode](standard_library_types.md#union-mode) for details. + extra: (Deprecated) Extra fields that will be included in the JSON schema. + + !!! warning Deprecated + The `extra` kwargs is deprecated. Use `json_schema_extra` instead. + + Returns: + A new [`FieldInfo`][pydantic.fields.FieldInfo]. The return annotation is `Any` so `Field` can be used on + type-annotated fields without causing a type error. + """ + field = Field( + default=default, + default_factory=default_factory, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + examples=examples, + exclude=exclude, + discriminator=discriminator, + json_schema_extra=json_schema_extra, + frozen=frozen, + validate_default=validate_default, + repr=repr, + init=init, + init_var=init_var, + kw_only=kw_only, + pattern=pattern, + strict=strict, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_length=min_length, + max_length=max_length, + union_mode=union_mode, + **extra, + ) + + field = cast(FieldInfo, field) + field.metadata.append(AllowMissingAnnotation()) + + field = cast(Any, field) + return field diff --git a/src/jmp/lightning/actsave.py b/src/jmp/lightning/actsave.py new file mode 100644 index 0000000..8cf3626 --- /dev/null +++ b/src/jmp/lightning/actsave.py @@ -0,0 +1,463 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib +import fnmatch +import tempfile +import uuid +import weakref +from dataclasses import dataclass, field +from functools import cached_property, wraps +from logging import getLogger +from pathlib import Path +from typing import Callable, Generic, Mapping, Union, cast, overload + +import numpy as np +import torch +from lightning.pytorch import LightningModule +from lightning_utilities.core.apply_func import apply_to_collection +from typing_extensions import ParamSpec, TypeVar, override + +log = getLogger(__name__) + +Value = Union[int, float, complex, bool, str, np.ndarray, torch.Tensor] +ValueOrLambda = Union[Value, Callable[..., Value]] + + +def _to_numpy(activation: Value) -> np.ndarray: + if isinstance(activation, np.ndarray): + return activation + if isinstance(activation, torch.Tensor): + activation = activation.detach() + if activation.is_floating_point(): + # NOTE: We need to convert to float32 because [b]float16 is not supported by numpy + activation = activation.float() + return activation.cpu().numpy() + if isinstance(activation, (int, float, complex, str, bool)): + return np.array(activation) + return activation + + +T = TypeVar("T", infer_variance=True) + + +# A wrapper around weakref.ref that allows for primitive types +# To get around errors like: +# TypeError: cannot create weak reference to 'int' object +class WeakRef(Generic[T]): + _ref: Callable[[], T] | None + + def __init__(self, obj: T): + try: + self._ref = cast(Callable[[], T], weakref.ref(obj)) + except TypeError as e: + if "cannot create weak reference" not in str(e): + raise + self._ref = lambda: obj + + def __call__(self) -> T: + if self._ref is None: + raise RuntimeError("WeakRef is deleted") + return self._ref() + + def delete(self): + del self._ref + self._ref = None + + +@dataclass +class Activation: + name: str + ref: WeakRef[ValueOrLambda] | None + transformed: np.ndarray | None = None + + def __post_init__(self): + # Update the `name` to replace `/` with `::` + self.name = self.name.replace("/", "::") + + def __call__(self) -> np.ndarray: + # If we have a transformed value, we return it + if self.transformed is not None: + return self.transformed + + if self.ref is None: + raise RuntimeError("Activation is deleted") + + # If we have a lambda, we need to call it + unrwapped_ref = self.ref() + activation = unrwapped_ref + if callable(unrwapped_ref): + activation = unrwapped_ref() + activation = apply_to_collection(activation, torch.Tensor, _to_numpy) + activation = _to_numpy(activation) + + # Set the transformed value + self.transformed = activation + + # Delete the reference + self.ref.delete() + del self.ref + self.ref = None + + return self.transformed + + @classmethod + def from_value_or_lambda(cls, name: str, value_or_lambda: ValueOrLambda): + return cls(name, WeakRef(value_or_lambda)) + + @classmethod + def from_dict(cls, d: Mapping[str, ValueOrLambda]): + return [cls.from_value_or_lambda(k, v) for k, v in d.items()] + + +Transform = Callable[[Activation], Mapping[str, ValueOrLambda]] + + +def _ensure_supported(): + try: + import torch.distributed as dist + + if dist.is_initialized() and dist.get_world_size() > 1: + raise RuntimeError("Only single GPU is supported at the moment") + except ImportError: + pass + + +P = ParamSpec("P") + + +def _ignore_if_scripting(fn: Callable[P, None]) -> Callable[P, None]: + @wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: + if torch.jit.is_scripting(): + return + + _ensure_supported() + fn(*args, **kwargs) + + return wrapper + + +class ActivationSaver: + def __init__( + self, + save_dir: Path, + prefixes_fn: Callable[[], list[str]], + *, + filters: list[str] | None = None, + transforms: list[tuple[str, Transform]] | None = None, + ): + # Create a directory under `save_dir` by autoincrementing + # (i.e., every activation save context, we create a new directory) + # The id = the number of activation subdirectories + self._id = sum(1 for subdir in save_dir.glob("*") if subdir.is_dir()) + save_dir.mkdir(parents=True, exist_ok=True) + + # Add a .activationbase file to the save_dir to indicate that this is an activation base + (save_dir / ".activationbase").touch(exist_ok=True) + + self._save_dir = save_dir / f"{self._id:04d}" + # Make sure `self._save_dir` does not exist and create it + self._save_dir.mkdir(exist_ok=False) + + self._prefixes_fn = prefixes_fn + self._filters = filters + self._transforms = transforms + + def _save_activation(self, activation: Activation): + # Save the activation to self._save_dir / name / {id}.npz, where id is an auto-incrementing integer + file_name = ".".join(self._prefixes_fn() + [activation.name]) + path = self._save_dir / file_name + path.mkdir(exist_ok=True, parents=True) + + # Get the next id and save the activation + id = len(list(path.glob("*.npy"))) + np.save(path / f"{id:04d}.npy", activation()) + + @_ignore_if_scripting + def save( + self, + acts: dict[str, ValueOrLambda] | None = None, + /, + **kwargs: ValueOrLambda, + ): + kwargs.update(acts or {}) + + # Build activations + activations = Activation.from_dict(kwargs) + + transformed_activations: list[Activation] = [] + + for activation in activations: + # Make sure name matches at least one filter if filters are specified + if self._filters is None or any( + fnmatch.fnmatch(activation.name, f) for f in self._filters + ): + self._save_activation(activation) + + # If we have any transforms, we need to apply them + if self._transforms: + # Iterate through transforms and apply them + for name, transform in self._transforms: + # If the transform doesn't match, we skip it + if not fnmatch.fnmatch(activation.name, name): + continue + + # Apply the transform + transform_out = transform(activation) + + # If the transform returns empty, we skip it + if not transform_out: + continue + + # Otherwise, add the transform to the activations + transformed_activations.extend(Activation.from_dict(transform_out)) + + # Now, we save the transformed activations + for transformed_activation in transformed_activations: + self._save_activation(transformed_activation) + + del activations + del transformed_activations + + +class ActSaveProvider: + _saver: ActivationSaver | None = None + _prefixes: list[str] = [] + + def initialize( + self, + save_dir: Path | None = None, + *, + filters: list[str] | None = None, + transforms: list[tuple[str, Transform]] | None = None, + ): + if self._saver is None: + if save_dir is None: + save_dir = Path(tempfile.gettempdir()) / f"actsave-{uuid.uuid4()}" + log.critical(f"No save_dir specified, using {save_dir=}") + self._saver = ActivationSaver( + save_dir, + lambda: self._prefixes, + filters=filters, + transforms=transforms, + ) + + @contextlib.contextmanager + def enabled( + self, + save_dir: Path | None = None, + *, + filters: list[str] | None = None, + transforms: list[tuple[str, Transform]] | None = None, + ): + prev = self._saver + self.initialize(save_dir, filters=filters, transforms=transforms) + try: + yield + finally: + self._saver = prev + + @override + def __init__(self): + super().__init__() + + self._saver = None + self._prefixes = [] + + @contextlib.contextmanager + def context(self, label: str): + if torch.jit.is_scripting(): + yield + return + + if self._saver is None: + yield + return + + _ensure_supported() + + log.debug(f"Entering ActSave context {label}") + self._prefixes.append(label) + try: + yield + finally: + _ = self._prefixes.pop() + + prefix = context + + @overload + def __call__( + self, + acts: dict[str, ValueOrLambda] | None = None, + /, + **kwargs: ValueOrLambda, + ): ... + + @overload + def __call__(self, acts: Callable[[], dict[str, ValueOrLambda]], /): ... + + def __call__( + self, + acts: ( + dict[str, ValueOrLambda] | Callable[[], dict[str, ValueOrLambda]] | None + ) = None, + /, + **kwargs: ValueOrLambda, + ): + if torch.jit.is_scripting(): + return + + if self._saver is None: + return + + if acts is not None and callable(acts): + acts = acts() + self._saver.save(acts, **kwargs) + + save = __call__ + + +@dataclass +class LoadedActivation: + base_dir: Path = field(repr=False) + name: str + num_activations: int = field(init=False) + activation_files: list[Path] = field(init=False, repr=False) + + def __post_init__(self): + if not self.activation_dir.exists(): + raise ValueError(f"Activation dir {self.activation_dir} does not exist") + + # The number of activations = the * of .npy files in the activation dir + self.activation_files = list(self.activation_dir.glob("*.npy")) + # Sort the activation files by the numerical index in the filename + self.activation_files.sort(key=lambda p: int(p.stem)) + self.num_activations = len(self.activation_files) + + @property + def activation_dir(self) -> Path: + return self.base_dir / self.name + + def _load_activation(self, item: int): + activation_path = self.activation_files[item] + if not activation_path.exists(): + raise ValueError(f"Activation {activation_path} does not exist") + return cast(np.ndarray, np.load(activation_path, allow_pickle=True)) + + @overload + def __getitem__(self, item: int) -> np.ndarray: ... + + @overload + def __getitem__(self, item: slice | list[int]) -> list[np.ndarray]: ... + + def __getitem__( + self, item: int | slice | list[int] + ) -> np.ndarray | list[np.ndarray]: + if isinstance(item, int): + return self._load_activation(item) + elif isinstance(item, slice): + return [ + self._load_activation(i) + for i in range(*item.indices(self.num_activations)) + ] + elif isinstance(item, list): + return [self._load_activation(i) for i in item] + else: + raise TypeError(f"Invalid type {type(item)} for item {item}") + + def __iter__(self): + return iter(self[i] for i in range(self.num_activations)) + + def __len__(self): + return self.num_activations + + def all_activations(self): + return [self[i] for i in range(self.num_activations)] + + +class ActivationLoader: + @classmethod + def all_versions(cls, dir: str | Path): + dir = Path(dir) + + # If the dir is not an activation base directory, we return None + if not (dir / ".activationbase").exists(): + return None + + # The contents of `dir` should be directories, each of which is a version. + return [ + (subdir, int(subdir.name)) for subdir in dir.iterdir() if subdir.is_dir() + ] + + @classmethod + def is_valid_activation_base(cls, dir: str | Path): + return cls.all_versions(dir) is not None + + @classmethod + def from_latest_version(cls, dir: str | Path): + # The contents of `dir` should be directories, each of which is a version + # We need to find the latest version + if (all_versions := cls.all_versions(dir)) is None: + raise ValueError(f"{dir} is not an activation base directory") + + path, _ = max(all_versions, key=lambda p: p[1]) + return cls(path) + + def __init__(self, dir: Path): + self._dir = dir + + def activation(self, name: str): + return LoadedActivation(self._dir, name) + + @cached_property + def activations(self): + return { + p.name: LoadedActivation(self._dir, p.name) for p in self._dir.iterdir() + } + + def __iter__(self): + return iter(self.activations.values()) + + def __getitem__(self, item: str): + return self.activations[item] + + def __len__(self): + return len(self.activations) + + @override + def __repr__(self): + return f"ActivationLoader(dir={self._dir}, activations={list(self.activations.values())})" + + +ActSave = ActSaveProvider() + + +def _wrap_fn(module: LightningModule, fn_name: str): + old_step = getattr(module, fn_name).__func__ + + @wraps(old_step) + def new_step(module: LightningModule, batch, batch_idx, *args, **kwargs): + with ActSave.context(fn_name): + return old_step(module, batch, batch_idx, *args, **kwargs) + + setattr(module, fn_name, new_step.__get__(module)) + log.info(f"Wrapped {fn_name} for actsave") + + +def wrap_lightning_module(module: LightningModule): + log.info( + "Wrapping training_step/validation_step/test_step/predict_step for actsave" + ) + + _wrap_fn(module, "training_step") + _wrap_fn(module, "validation_step") + _wrap_fn(module, "test_step") + _wrap_fn(module, "predict_step") + + log.info("Wrapped training_step/validation_step/test_step/predict_step for actsave") diff --git a/src/jmp/lightning/callbacks/__init__.py b/src/jmp/lightning/callbacks/__init__.py new file mode 100644 index 0000000..85cf48f --- /dev/null +++ b/src/jmp/lightning/callbacks/__init__.py @@ -0,0 +1,19 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .bad_gradients import PrintBadGradientsCallback +from .ema import EMA +from .interval import EpochIntervalCallback, IntervalCallback, StepIntervalCallback + +__all__ = [ + "PrintBadGradientsCallback", + "EMA", + "EpochIntervalCallback", + "IntervalCallback", + "StepIntervalCallback", +] diff --git a/src/jmp/lightning/callbacks/bad_gradients.py b/src/jmp/lightning/callbacks/bad_gradients.py new file mode 100644 index 0000000..0e69a9e --- /dev/null +++ b/src/jmp/lightning/callbacks/bad_gradients.py @@ -0,0 +1,63 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger + +import torch +from lightning.pytorch import Callback, LightningModule, Trainer +from typing_extensions import override + +log = getLogger(__name__) + + +def print_bad_gradients( + module: LightningModule, + nonfinite_grads: bool = True, + none_grads: bool = False, +): + for name, param in module.named_parameters(): + if not param.requires_grad: + continue + + if param.grad is None: + if none_grads: + log.critical(f"Parameter {name} ({param.shape}) has None gradients") + continue + + if not nonfinite_grads or torch.isfinite(param.grad.float()).all(): + continue + + has_nan = torch.isnan(param.grad.float()).any() + has_inf = torch.isinf(param.grad.float()).any() + kinds = [ + "NaN" if has_nan else None, + "Inf" if has_inf else None, + ] + kinds = " and ".join(prop for prop in kinds if prop is not None) + log.critical(f"{name} ({param.shape}) has {kinds} gradients") + + +class PrintBadGradientsCallback(Callback): + def __init__( + self, + *, + nonfinite_grads: bool = True, + none_grads: bool = False, + ): + super().__init__() + + self._nonfinite_grads = nonfinite_grads + self._none_grads = none_grads + + @override + def on_after_backward(self, _trainer: Trainer, module: LightningModule): + print_bad_gradients( + module, + nonfinite_grads=self._nonfinite_grads, + none_grads=self._none_grads, + ) diff --git a/src/jmp/lightning/callbacks/ema.py b/src/jmp/lightning/callbacks/ema.py new file mode 100644 index 0000000..5e87851 --- /dev/null +++ b/src/jmp/lightning/callbacks/ema.py @@ -0,0 +1,353 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib +import copy +import threading +from typing import Iterable + +import lightning.pytorch as pl +import torch +from lightning.pytorch import Callback +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from typing_extensions import override + + +class EMA(Callback): + """ + Implements Exponential Moving Averaging (EMA). + + When training a model, this callback will maintain moving averages of the trained parameters. + When evaluating, we use the moving averages copy of the trained parameters. + When saving, we save an additional set of parameters with the prefix `ema`. + + Args: + decay: The exponential decay used when calculating the moving average. Has to be between 0-1. + validate_original_weights: Validate the original weights, as apposed to the EMA weights. + every_n_steps: Apply EMA every N steps. + cpu_offload: Offload weights to CPU. + """ + + @override + def __init__( + self, + decay: float, + validate_original_weights: bool = False, + every_n_steps: int = 1, + cpu_offload: bool = False, + ): + if not (0 <= decay <= 1): + raise MisconfigurationException("EMA decay value must be between 0 and 1") + self.decay = decay + self.validate_original_weights = validate_original_weights + self.every_n_steps = every_n_steps + self.cpu_offload = cpu_offload + + @override + def on_fit_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + device = pl_module.device if not self.cpu_offload else torch.device("cpu") + trainer.optimizers = [ + EMAOptimizer( + optim, + device=device, + decay=self.decay, + every_n_steps=self.every_n_steps, + current_step=trainer.global_step, + ) + for optim in trainer.optimizers + if not isinstance(optim, EMAOptimizer) + ] + + @override + def on_validation_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + @override + def on_validation_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + @override + def on_test_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + @override + def on_test_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool: + return not self.validate_original_weights and self._ema_initialized(trainer) + + def _ema_initialized(self, trainer: "pl.Trainer") -> bool: + return any( + isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers + ) + + def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.switch_main_parameter_weights(saving_ema_model) + + @contextlib.contextmanager + def save_ema_model(self, trainer: "pl.Trainer"): + """ + Saves an EMA copy of the model + EMA optimizer states for resume. + """ + self.swap_model_weights(trainer, saving_ema_model=True) + try: + yield + finally: + self.swap_model_weights(trainer, saving_ema_model=False) + + @contextlib.contextmanager + def save_original_optimizer_state(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.save_original_optimizer_state = True + try: + yield + finally: + for optimizer in trainer.optimizers: + optimizer.save_original_optimizer_state = False + + +@torch.no_grad() +def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, + current_model_tuple, + alpha=(1.0 - decay), + ) + + +def run_ema_update_cpu( + ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None +): + if pre_sync_stream is not None: + pre_sync_stream.synchronize() + + ema_update(ema_model_tuple, current_model_tuple, decay) + + +class EMAOptimizer(torch.optim.Optimizer): + r""" + EMAOptimizer is a wrapper for torch.optim.Optimizer that computes + Exponential Moving Average of parameters registered in the optimizer. + + EMA parameters are automatically updated after every step of the optimizer + with the following formula: + + ema_weight = decay * ema_weight + (1 - decay) * training_weight + + To access EMA parameters, use ``swap_ema_weights()`` context manager to + perform a temporary in-place swap of regular parameters with EMA + parameters. + + Notes: + - EMAOptimizer is not compatible with APEX AMP O2. + + Args: + optimizer (torch.optim.Optimizer): optimizer to wrap + device (torch.device): device for EMA parameters + decay (float): decay factor + + Returns: + returns an instance of torch.optim.Optimizer that computes EMA of + parameters + + Example: + model = Model().to(device) + opt = torch.optim.Adam(model.parameters()) + + opt = EMAOptimizer(opt, device, 0.9999) + + for epoch in range(epochs): + training_loop(model, opt) + + regular_eval_accuracy = evaluate(model) + + with opt.swap_ema_weights(): + ema_eval_accuracy = evaluate(model) + """ + + @override + def __init__( + self, + optimizer: torch.optim.Optimizer, + device: torch.device, + decay: float = 0.9999, + every_n_steps: int = 1, + current_step: int = 0, + ): + self.optimizer = optimizer + self.decay = decay + self.device = device + self.current_step = current_step + self.every_n_steps = every_n_steps + self.save_original_optimizer_state = False + + self.first_iteration = True + self.rebuild_ema_params = True + self.stream = None + self.thread = None + + self.ema_params = () + self.in_saving_ema_model_context = False + + def all_parameters(self) -> Iterable[torch.Tensor]: + return (param for group in self.param_groups for param in group["params"]) + + @override + def step(self, closure=None, **kwargs): + self.join() + + if self.first_iteration: + if any(p.is_cuda for p in self.all_parameters()): + self.stream = torch.cuda.Stream() + + self.first_iteration = False + + if self.rebuild_ema_params: + opt_params = list(self.all_parameters()) + + self.ema_params += tuple( + copy.deepcopy(param.data.detach()).to(self.device) + for param in opt_params[len(self.ema_params) :] + ) + self.rebuild_ema_params = False + + loss = self.optimizer.step(closure) + + if self._should_update_at_step(): + self.update() + self.current_step += 1 + return loss + + def _should_update_at_step(self) -> bool: + return self.current_step % self.every_n_steps == 0 + + @torch.no_grad() + def update(self): + if self.stream is not None: + self.stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.stream): + current_model_state = tuple( + param.data.to(self.device, non_blocking=True) + for param in self.all_parameters() + ) + + if self.device.type == "cuda": + ema_update(self.ema_params, current_model_state, self.decay) + + if self.device.type == "cpu": + self.thread = threading.Thread( + target=run_ema_update_cpu, + args=( + self.ema_params, + current_model_state, + self.decay, + self.stream, + ), + ) + self.thread.start() + + def swap_tensors(self, tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) + + def switch_main_parameter_weights(self, saving_ema_model: bool = False): + self.join() + self.in_saving_ema_model_context = saving_ema_model + for param, ema_param in zip(self.all_parameters(), self.ema_params): + self.swap_tensors(param.data, ema_param) + + @contextlib.contextmanager + def swap_ema_weights(self, enabled: bool = True): + r""" + A context manager to in-place swap regular parameters with EMA + parameters. + It swaps back to the original regular parameters on context manager + exit. + + Args: + enabled (bool): whether the swap should be performed + """ + + if enabled: + self.switch_main_parameter_weights() + try: + yield + finally: + if enabled: + self.switch_main_parameter_weights() + + def __getattr__(self, name): + return getattr(self.optimizer, name) + + def join(self): + if self.stream is not None: + self.stream.synchronize() + + if self.thread is not None: + self.thread.join() + + @override + def state_dict(self): + self.join() + + if self.save_original_optimizer_state: + return self.optimizer.state_dict() + + # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights + ema_params = ( + self.ema_params + if not self.in_saving_ema_model_context + else list(self.all_parameters()) + ) + state_dict = { + "opt": self.optimizer.state_dict(), + "ema": ema_params, + "current_step": self.current_step, + "decay": self.decay, + "every_n_steps": self.every_n_steps, + } + return state_dict + + @override + def load_state_dict(self, state_dict): + self.join() + + self.optimizer.load_state_dict(state_dict["opt"]) + self.ema_params = tuple( + param.to(self.device) for param in copy.deepcopy(state_dict["ema"]) + ) + self.current_step = state_dict["current_step"] + self.decay = state_dict["decay"] + self.every_n_steps = state_dict["every_n_steps"] + self.rebuild_ema_params = False + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + self.rebuild_ema_params = True diff --git a/src/jmp/lightning/callbacks/interval.py b/src/jmp/lightning/callbacks/interval.py new file mode 100644 index 0000000..117209b --- /dev/null +++ b/src/jmp/lightning/callbacks/interval.py @@ -0,0 +1,249 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Callable, Literal + +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback +from typing_extensions import override + +Split = Literal["train", "val", "test", "predict"] + + +def _check_step(step: int, interval: int, skip_first: bool = False): + if step % interval != 0: + return False + if skip_first and step == 0: + return False + return True + + +class StepIntervalCallback(Callback): + def __init__( + self, + function: Callable[[Trainer, LightningModule], None], + *, + interval: int, + skip_first: bool = False, + splits: list[Split] = ["train", "val", "test", "predict"], + ): + super().__init__() + + self.function = function + self.interval = interval + self.skip_first = skip_first + self.splits = set(splits) + + @override + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + if ( + not _check_step( + trainer.global_step, + self.interval, + skip_first=self.skip_first, + ) + or "train" not in self.splits + ): + return + self.function(trainer, pl_module) + + @override + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): + if ( + not _check_step( + trainer.global_step, + self.interval, + skip_first=self.skip_first, + ) + or "val" not in self.splits + ): + return + self.function(trainer, pl_module) + + @override + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx): + if ( + not _check_step( + trainer.global_step, + self.interval, + skip_first=self.skip_first, + ) + or "test" not in self.splits + ): + return + self.function(trainer, pl_module) + + @override + def on_predict_batch_start(self, trainer, pl_module, batch, batch_idx): + if ( + not _check_step( + trainer.global_step, + self.interval, + skip_first=self.skip_first, + ) + or "predict" not in self.splits + ): + return + self.function(trainer, pl_module) + + +class EpochIntervalCallback(Callback): + def __init__( + self, + function: Callable[[Trainer, LightningModule], None], + *, + interval: int, + skip_first: bool = False, + splits: list[Split] = ["train", "val", "test", "predict"], + ): + super().__init__() + + self.function = function + self.interval = interval + self.skip_first = skip_first + self.splits = set(splits) + + @override + def on_train_epoch_start(self, trainer, pl_module): + if ( + not _check_step( + trainer.current_epoch, + self.interval, + skip_first=self.skip_first, + ) + or "train" not in self.splits + ): + return + self.function(trainer, pl_module) + + @override + def on_validation_epoch_start(self, trainer, pl_module): + if ( + not _check_step( + trainer.current_epoch, + self.interval, + skip_first=self.skip_first, + ) + or "val" not in self.splits + ): + return + self.function(trainer, pl_module) + + @override + def on_test_epoch_start(self, trainer, pl_module): + if ( + not _check_step( + trainer.current_epoch, + self.interval, + skip_first=self.skip_first, + ) + or "test" not in self.splits + ): + return + self.function(trainer, pl_module) + + @override + def on_predict_epoch_start(self, trainer, pl_module): + if ( + not _check_step( + trainer.current_epoch, + self.interval, + skip_first=self.skip_first, + ) + or "predict" not in self.splits + ): + return + self.function(trainer, pl_module) + + +class IntervalCallback(Callback): + def __init__( + self, + function: Callable[[Trainer, LightningModule], None], + *, + step_interval: int | None = None, + epoch_interval: int | None = None, + skip_first: bool = False, + splits: list[Split] = ["train", "val", "test", "predict"], + ): + super().__init__() + + self.callback = None + + if step_interval is not None: + self.callback = StepIntervalCallback( + function, + interval=step_interval, + splits=splits, + skip_first=skip_first, + ) + elif epoch_interval is not None: + self.callback = EpochIntervalCallback( + function, + interval=epoch_interval, + splits=splits, + skip_first=skip_first, + ) + else: + raise ValueError("Either step_interval or epoch_interval must be specified") + + @override + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + if not isinstance(self.callback, StepIntervalCallback): + return + + self.callback.on_train_batch_start(trainer, pl_module, batch, batch_idx) + + @override + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx): + if not isinstance(self.callback, StepIntervalCallback): + return + + self.callback.on_validation_batch_start(trainer, pl_module, batch, batch_idx) + + @override + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx): + if not isinstance(self.callback, StepIntervalCallback): + return + + self.callback.on_test_batch_start(trainer, pl_module, batch, batch_idx) + + @override + def on_predict_batch_start(self, trainer, pl_module, batch, batch_idx): + if not isinstance(self.callback, StepIntervalCallback): + return + + self.callback.on_predict_batch_start(trainer, pl_module, batch, batch_idx) + + @override + def on_train_epoch_start(self, trainer, pl_module): + if not isinstance(self.callback, EpochIntervalCallback): + return + + self.callback.on_train_epoch_start(trainer, pl_module) + + @override + def on_validation_epoch_start(self, trainer, pl_module): + if not isinstance(self.callback, EpochIntervalCallback): + return + + self.callback.on_validation_epoch_start(trainer, pl_module) + + @override + def on_test_epoch_start(self, trainer, pl_module): + if not isinstance(self.callback, EpochIntervalCallback): + return + + self.callback.on_test_epoch_start(trainer, pl_module) + + @override + def on_predict_epoch_start(self, trainer, pl_module): + if not isinstance(self.callback, EpochIntervalCallback): + return + + self.callback.on_predict_epoch_start(trainer, pl_module) diff --git a/src/jmp/lightning/config.py b/src/jmp/lightning/config.py new file mode 100644 index 0000000..311b36a --- /dev/null +++ b/src/jmp/lightning/config.py @@ -0,0 +1,301 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections.abc import Mapping, MutableMapping +from typing import TYPE_CHECKING, Any, ClassVar + +from pydantic import BaseModel, ConfigDict +from pydantic import Field as Field +from pydantic import PrivateAttr as PrivateAttr +from typing_extensions import override + +from ._config.missing import MISSING, validate_no_missing_values +from ._config.missing import AllowMissing as AllowMissing +from ._config.missing import MissingField as MissingField + +_MutableMappingBase = MutableMapping[str, Any] +if TYPE_CHECKING: + _MutableMappingBase = object + + +_DraftConfigContextSentinel = object() + + +class TypedConfig(BaseModel, _MutableMappingBase): + _is_draft_config: bool = PrivateAttr(default=False) + """ + Whether this config is a draft config or not. + + Draft configs are configs that are not yet fully validated. + They allow for a nicer API when creating configs, e.g.: + + ```python + config = MyConfig.draft() + + # Set some values + config.a = 10 + config.b = "hello" + + # Finalize the config + config = config.finalize() + ``` + """ + + repr_diff_only: ClassVar[bool] = True + """ + If `True`, the repr methods will only show values for fields that are different from the default. + """ + + MISSING: ClassVar[Any] = MISSING + """ + Alias for the `MISSING` constant. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict( + # By default, Pydantic will throw a warning if a field starts with "model_", + # so we need to disable that warning (beacuse "model_" is a popular prefix for ML). + protected_namespaces=(), + validate_assignment=True, + strict=True, + revalidate_instances="always", + arbitrary_types_allowed=True, + extra="ignore", + ) + + def __draft_pre_init__(self): + """Called right before a draft config is finalized.""" + pass + + def __post_init__(self): + """Called after the final config is validated.""" + pass + + @classmethod + def from_dict(cls, model_dict: Mapping[str, Any]): + return cls.model_validate(model_dict) + + def model_deep_validate(self, strict: bool = True): + """ + Validate the config and all of its sub-configs. + + Args: + config: The config to validate. + strict: Whether to validate the config strictly. + """ + config_dict = self.model_dump(round_trip=True) + config = self.model_validate(config_dict, strict=strict) + + # Make sure that this is not a draft config + if config._is_draft_config: + raise ValueError("Draft configs are not valid. Call `finalize` first.") + + return config + + @classmethod + def draft(cls, **kwargs): + config = cls.model_construct_draft(**kwargs) + return config + + def finalize(self, strict: bool = True): + # This must be a draft config, otherwise we raise an error + if not self._is_draft_config: + raise ValueError("Finalize can only be called on drafts.") + + # First, we call `__draft_pre_init__` to allow the config to modify itself a final time + self.__draft_pre_init__() + + # Then, we dump the config to a dict and then re-validate it + return self.model_deep_validate(strict=strict) + + @override + def model_post_init(self, __context: Any) -> None: + super().model_post_init(__context) + + # Call the `__post_init__` method if this is not a draft config + if __context is _DraftConfigContextSentinel: + return + + self.__post_init__() + + # After `_post_init__` is called, we perform the final round of validation + self.model_post_init_validate() + + def model_post_init_validate(self): + validate_no_missing_values(self) + + @classmethod + def model_construct_draft(cls, _fields_set: set[str] | None = None, **values: Any): + """ + NOTE: This is a copy of the `model_construct` method from Pydantic's `Model` class, + with the following changes: + - The `model_post_init` method is called with the `_DraftConfigContext` context. + - The `_is_draft_config` attribute is set to `True` in the `values` dict. + + Creates a new instance of the `Model` class with validated data. + + Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data. + Default values are respected, but no other validation is performed. + + !!! note + `model_construct()` generally respects the `model_config.extra` setting on the provided model. + That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__` + and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored. + Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in + an error if extra values are passed, but they will be ignored. + + Args: + _fields_set: The set of field names accepted for the Model instance. + values: Trusted or pre-validated data dictionary. + + Returns: + A new instance of the `Model` class with validated data. + """ + + values["_is_draft_config"] = True + + m = cls.__new__(cls) + fields_values: dict[str, Any] = {} + fields_set = set() + + for name, field in cls.model_fields.items(): + if field.alias and field.alias in values: + fields_values[name] = values.pop(field.alias) + fields_set.add(name) + elif name in values: + fields_values[name] = values.pop(name) + fields_set.add(name) + elif not field.is_required(): + fields_values[name] = field.get_default(call_default_factory=True) + if _fields_set is None: + _fields_set = fields_set + + _extra: dict[str, Any] | None = None + if cls.model_config.get("extra") == "allow": + _extra = {} + for k, v in values.items(): + _extra[k] = v + object.__setattr__(m, "__dict__", fields_values) + object.__setattr__(m, "__pydantic_fields_set__", _fields_set) + if not cls.__pydantic_root_model__: + object.__setattr__(m, "__pydantic_extra__", _extra) + + if cls.__pydantic_post_init__: + m.model_post_init(_DraftConfigContextSentinel) + # update private attributes with values set + if ( + hasattr(m, "__pydantic_private__") + and m.__pydantic_private__ is not None + ): + for k, v in values.items(): + if k in m.__private_attributes__: + m.__pydantic_private__[k] = v + + elif not cls.__pydantic_root_model__: + # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist + # Since it doesn't, that means that `__pydantic_private__` should be set to None + object.__setattr__(m, "__pydantic_private__", None) + + return m + + @override + def __repr_args__(self): + # If `repr_diff_only` is `True`, we only show the fields that are different from the default. + if not self.repr_diff_only: + yield from super().__repr_args__() + return + + # First, we get the default values for all fields. + default_values = self.model_construct_draft() + + # Then, we compare the default values with the current values. + for k, v in super().__repr_args__(): + if k is None: + yield k, v + continue + + # If there is no default value or the value is different from the default, we yield it. + if not hasattr(default_values, k) or getattr(default_values, k) != v: + yield k, v + continue + + # Otherwise, we can skip this field. + + # region MutableMapping implementation + if not TYPE_CHECKING: + # This is mainly so the config can be used with lightning's hparams + # transparently and without any issues. + + @property + def _ll_dict(self): + return self.model_dump() + + # We need to make sure every config class + # is a MutableMapping[str, Any] so that it can be used + # with lightning's hparams. + @override + def __getitem__(self, key: str): + # Key can be of the format "a.b.c" + # so we need to split it into a list of keys. + [first_key, *rest_keys] = key.split(".") + value = self._ll_dict[first_key] + + for key in rest_keys: + if isinstance(value, Mapping): + value = value[key] + else: + value = getattr(value, key) + + return value + + @override + def __setitem__(self, key: str, value: Any): + # Key can be of the format "a.b.c" + # so we need to split it into a list of keys. + [first_key, *rest_keys] = key.split(".") + if len(rest_keys) == 0: + self._ll_dict[first_key] = value + return + + # We need to traverse the keys until we reach the last key + # and then set the value + current_value = self._ll_dict[first_key] + for key in rest_keys[:-1]: + if isinstance(current_value, Mapping): + current_value = current_value[key] + else: + current_value = getattr(current_value, key) + + # Set the value + if isinstance(current_value, MutableMapping): + current_value[rest_keys[-1]] = value + else: + setattr(current_value, rest_keys[-1], value) + + @override + def __delitem__(self, key: str): + # This is unsupported for this class + raise NotImplementedError + + @override + def __iter__(self): + return iter(self._ll_dict) + + @override + def __len__(self): + return len(self._ll_dict) + + # endregion + + +__all__ = [ + "TypedConfig", + "Field", + "PrivateAttr", + "AllowMissing", + "MissingField", +] diff --git a/src/jmp/lightning/data/__init__.py b/src/jmp/lightning/data/__init__.py new file mode 100644 index 0000000..6d9a216 --- /dev/null +++ b/src/jmp/lightning/data/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from . import transform as dataset_transform +from .balanced_batch_sampler import BalancedBatchSampler + +__all__ = [ + "BalancedBatchSampler", + "dataset_transform", +] diff --git a/src/jmp/lightning/data/balanced_batch_sampler.py b/src/jmp/lightning/data/balanced_batch_sampler.py new file mode 100644 index 0000000..1a5ed5c --- /dev/null +++ b/src/jmp/lightning/data/balanced_batch_sampler.py @@ -0,0 +1,140 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import heapq +from functools import cached_property +from logging import getLogger +from typing import Any, Protocol, runtime_checkable + +import numpy as np +import torch +import torch.distributed +from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper +from torch.utils.data import BatchSampler, Dataset, DistributedSampler +from typing_extensions import override + +log = getLogger(__name__) + + +def _all_gather(tensor: torch.Tensor, device: torch.device | None = None): + gathered = [ + torch.zeros_like(tensor, device=device) + for _ in range(torch.distributed.get_world_size()) + ] + _ = torch.distributed.all_gather(gathered, tensor) + return gathered + + +# @numba.njit +def _balanced_partition(sizes: np.ndarray, num_parts: int): + """ + Greedily partition the given set by always inserting + the largest element into the smallest partition. + """ + sort_idx = np.argsort(-sizes) # Sort in descending order + heap = [] + for idx in sort_idx[:num_parts]: + heap.append((sizes[idx], [idx])) + heapq.heapify(heap) + for idx in sort_idx[num_parts:]: + smallest_part = heapq.heappop(heap) + new_size = smallest_part[0] + sizes[idx] + new_idx = smallest_part[1] + [idx] + heapq.heappush(heap, (new_size, new_idx)) + idx_balanced = [part[1] for part in heap] + return idx_balanced + + +@runtime_checkable +class DatasetWithSizes(Protocol): + def data_sizes(self, indices: list[int]) -> np.ndarray: ... + + +class BalancedBatchSampler(BatchSampler): + @staticmethod + def _ensure_supported(dataset: Any): + if not isinstance(dataset, Dataset): + raise ValueError( + "BalancedBatchSampler requires a dataset that implements `__getitem__`" + ) + + if not isinstance(dataset, DatasetWithSizes): + raise ValueError( + "BalancedBatchSampler requires a dataset that implements `data_sizes`" + ) + + log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}") + return dataset + + @staticmethod + def _unwrap_dataset(dataset: Dataset) -> Dataset: + if isinstance(dataset, _DatasetSamplerWrapper): + if (data_source := getattr(dataset._sampler, "data_source", None)) is None: + raise ValueError("Could not unwrap dataset from _DatasetSamplerWrapper") + return data_source + return dataset + + @property + def distributed_sampler(self): + if not isinstance(self.sampler, DistributedSampler): + raise ValueError( + f"Sampler must be a DistributedSampler, got {type(self.sampler)}" + ) + return self.sampler + + @cached_property + def dataset(self): + return self._ensure_supported( + self._unwrap_dataset(self.distributed_sampler.dataset) + ) + + def __init__( + self, + sampler: DistributedSampler, + *, + batch_size: int, + device: torch.device, + drop_last: bool = False, + ): + super().__init__(sampler, batch_size, drop_last=drop_last) + + self._device = device + + log.info( + f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}" + ) + + @staticmethod + def _dist_enabled(): + return torch.distributed.is_available() and torch.distributed.is_initialized() + + @override + def __iter__(self): + if not self._dist_enabled(): + yield from super().__iter__() + return + + for batch_idx in super().__iter__(): + sizes = self.dataset.data_sizes(batch_idx) + idx_sizes = torch.stack( + [ + torch.tensor(batch_idx, device=self._device), + torch.tensor(sizes, device=self._device), + ] + ) + idx_sizes_all = _all_gather(idx_sizes, device=self._device) + idx_sizes_all = torch.cat(idx_sizes_all, dim=-1).cpu() + idx_all = idx_sizes_all[0] + sizes_all = idx_sizes_all[1] + + local_idx_balanced = _balanced_partition( + sizes_all.numpy(), num_parts=self.distributed_sampler.num_replicas + ) + # Since DistributedSampler pads the last batch + # this should always have an entry for each replica. + yield idx_all[local_idx_balanced[self.distributed_sampler.rank]].tolist() diff --git a/src/jmp/lightning/data/transform.py b/src/jmp/lightning/data/transform.py new file mode 100644 index 0000000..21ab696 --- /dev/null +++ b/src/jmp/lightning/data/transform.py @@ -0,0 +1,58 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from typing import Any, Callable, cast + +import wrapt +from typing_extensions import TypeVar, override + +TDataset = TypeVar("TDataset", infer_variance=True) + + +def transform( + dataset: TDataset, + transform: Callable[[Any], Any], + *, + deepcopy: bool = False, +) -> TDataset: + class _TransformedDataset(wrapt.ObjectProxy): + @override + def __getitem__(self, idx): + nonlocal deepcopy, transform + + data = self.__wrapped__.__getitem__(idx) + if deepcopy: + data = copy.deepcopy(data) + data = transform(data) + return data + + return cast(TDataset, _TransformedDataset(dataset)) + + +def transform_with_index( + dataset: TDataset, + transform: Callable[[Any, int], Any], + *, + deepcopy: bool = False, +) -> TDataset: + class _TransformedWithIndexDataset(wrapt.ObjectProxy): + @override + def __getitem__(self, idx: int): + nonlocal deepcopy, transform + + data = self.__wrapped__.__getitem__(idx) + if deepcopy: + data = copy.deepcopy(data) + data = transform(data, idx) + return data + + return cast(TDataset, _TransformedWithIndexDataset(dataset)) + + +__all__ = ["transform"] diff --git a/src/jmp/lightning/exception.py b/src/jmp/lightning/exception.py new file mode 100644 index 0000000..09d9079 --- /dev/null +++ b/src/jmp/lightning/exception.py @@ -0,0 +1,58 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Any + +from typing_extensions import override + + +class SkipBatch(BaseException): + """Exception to skip the current batch.""" + + +class TrainingError(Exception): + """Exception thrown during training which contains information about the batch that caused the error. + + Args: + exception: The exception that was thrown. + batch_idx: The index of the batch that caused the error. + batch: The batch that caused the error. + epoch: The epoch that the error occurred in. + global_step: The global step that the error occurred in. + training_fn: The training function that the error occurred in (one of "training_fn", "validation_step", "test_step", "predict_step"). + """ + + @override + def __init__( + self, + exception: BaseException, + *, + batch_idx: int, + batch: Any, + epoch: int, + global_step: int, + training_fn: str, + ): + self.exception = exception + self.batch_idx = batch_idx + self.batch = batch + self.epoch = epoch + self.global_step = global_step + self.training_fn = training_fn + + super().__init__( + f"Training error in training_fn {training_fn} at epoch {epoch} and global step {global_step} at batch {batch_idx}." + ) + + @override + def __repr__(self): + return f"TrainingError(batch_idx={self.batch_idx}, batch={self.batch}, epoch={self.epoch}, global_step={self.global_step}, training_fn={self.training_fn})" + + @override + def __str__(self): + return self.__repr__() diff --git a/src/jmp/lightning/local_sessions_runner.py b/src/jmp/lightning/local_sessions_runner.py new file mode 100644 index 0000000..c1f43a1 --- /dev/null +++ b/src/jmp/lightning/local_sessions_runner.py @@ -0,0 +1,65 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import logging +from pathlib import Path + +import cloudpickle as pickle + +log = logging.getLogger(__name__) + + +def process_session(path: Path) -> None: + log.critical(f"Executing {path}") + # Load the path pickle. It should be a tuple of (run_fn, runner_kwargs, config) + with path.open("rb") as file: + loaded = pickle.load(file) + + if not isinstance(loaded, tuple): + raise TypeError(f"Expected a tuple, got {type(loaded)}") + + if not len(loaded) == 3: + raise ValueError(f"Expected a tuple of length 3, got {len(loaded)}") + + run_fn, runner_kwargs, config = loaded + assert callable(run_fn), f"Expected a callable, got {type(run_fn)}" + assert isinstance( + runner_kwargs, dict + ), f"Expected a dict, got {type(runner_kwargs)}" + + # Execute the run_fn + from jmp.lightning.runner import Runner + + runner = Runner(run_fn, **runner_kwargs) + _ = runner([config]) + log.critical(f"Executed {path}") + + +def main(): + parser = argparse.ArgumentParser() + _ = parser.add_argument( + "paths", nargs="+", type=Path, help="Paths to the sessions to run" + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + + if not args.paths: + raise ValueError("No paths provided") + + log.critical(f"Executing {args.paths=}") + + for path in args.paths: + process_session(path) + + log.critical("All sessions executed") + + +if __name__ == "__main__": + main() diff --git a/src/jmp/lightning/model/base.py b/src/jmp/lightning/model/base.py new file mode 100644 index 0000000..9c08041 --- /dev/null +++ b/src/jmp/lightning/model/base.py @@ -0,0 +1,382 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import inspect +import json +import os +import sys +from abc import ABC, abstractmethod +from logging import getLogger +from typing import Any, Callable, Generic, cast + +import torch +import torch.nn as nn +from lightning.pytorch import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.callbacks import Callback +from typing_extensions import TypeVar, override + +from .. import actsave +from ..trainer import Trainer as LLTrainer +from ..util import log_batch_info, skip_batch +from .config import BaseConfig +from .modules.callback import CallbackModuleMixin, CallbackRegistrarModuleMixin +from .modules.debug import DebugModuleMixin +from .modules.distributed import DistributedMixin +from .modules.finite_checks import FiniteChecksModuleMixin +from .modules.log_dir import LogDirMixin +from .modules.log_epoch import LogEpochMixin +from .modules.logger import LoggerModuleMixin +from .modules.lr_monitor import LRMonitorMixin +from .modules.optimizer import OptimizerModuleMixin +from .modules.parameter_hooks import ParameterHookModuleMixin +from .modules.profiler import ProfilerMixin +from .modules.rlp_sanity_checks import RLPSanityCheckModuleMixin +from .modules.shared_parameters import SharedParametersModuleMixin +from .modules.wandb import WandbWrapperMixin + +log = getLogger(__name__) + +THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True) + + +class _ResidualSequential(nn.Sequential): + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + super().forward(x) + + +class Base(DebugModuleMixin, Generic[THparams]): + @torch.jit.unused + def mlp( + self, + dims: list[int], + *, + activation: Callable[[], nn.Module], + bias: bool = True, + no_bias_scalar: bool = True, + ln: bool | str = False, + dropout: float | None = None, + residual: bool = False, + pre_layers: list[nn.Module] = [], + post_layers: list[nn.Module] = [], + ) -> nn.Sequential: + """ + Constructs a multi-layer perceptron (MLP) with the given dimensions and activation function. + + Args: + dims (list[int]): List of integers representing the dimensions of the MLP. + activation (Callable[[], nn.Module]): Activation function to use between layers. + bias (bool, optional): Whether to include bias terms in the linear layers. Defaults to True. + no_bias_scalar (bool, optional): Whether to exclude bias terms when the output dimension is 1. Defaults to True. + ln (bool | str, optional): Whether to apply layer normalization before or after the linear layers. Defaults to False. + dropout (float | None, optional): Dropout probability to apply between layers. Defaults to None. + residual (bool, optional): Whether to use residual connections between layers. Defaults to False. + pre_layers (list[nn.Module], optional): List of layers to insert before the linear layers. Defaults to []. + post_layers (list[nn.Module], optional): List of layers to insert after the linear layers. Defaults to []. + + Returns: + nn.Sequential: The constructed MLP. + """ + + if len(dims) < 2: + raise ValueError("mlp requires at least 2 dimensions") + if ln is True: + ln = "pre" + elif isinstance(ln, str) and ln not in ("pre", "post"): + raise ValueError("ln must be a boolean or 'pre' or 'post'") + + layers: list[nn.Module] = [] + if ln == "pre": + layers.append(nn.LayerNorm(dims[0])) + + layers.extend(pre_layers) + + for i in range(len(dims) - 1): + in_features = dims[i] + out_features = dims[i + 1] + bias_ = bias and not (no_bias_scalar and out_features == 1) + layers.append(nn.Linear(in_features, out_features, bias=bias_)) + if dropout is not None: + layers.append(nn.Dropout(dropout)) + if i < len(dims) - 2: + layers.append(activation()) + + layers.extend(post_layers) + + if ln == "post": + layers.append(nn.LayerNorm(dims[-1])) + + cls = _ResidualSequential if residual else nn.Sequential + return cls(*layers) + + @torch.jit.unused + @property + def config(self) -> THparams: + return self.hparams + + @torch.jit.unused + @property + def C(self) -> THparams: + return self.hparams + + @property + def debug(self) -> bool: + if torch.jit.is_scripting(): + return False + return self.config.debug + + @property + def dev(self) -> bool: + if torch.jit.is_scripting(): + return False + return self.config.debug + + @override + def __init__(self, hparams: THparams): + super().__init__() + + if not hasattr(self, "hparams"): + self.hparams = hparams + + +class DebugFlagCallback(Callback): + """ + Sets the debug flag to true in the following circumstances: + - fast_dev_run is enabled + - sanity check is running + """ + + @override + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str): + if not getattr(trainer, "fast_dev_run", False): + return + + hparams = cast(BaseConfig, pl_module.hparams) + if not hparams.debug: + log.critical("Fast dev run detected, setting debug flag to True.") + hparams.debug = True + + @override + def on_sanity_check_start(self, trainer: Trainer, pl_module: LightningModule): + hparams = cast(BaseConfig, pl_module.hparams) + self._debug = hparams.debug + if not self._debug: + log.critical("Enabling debug flag during sanity check routine.") + hparams.debug = True + + @override + def on_sanity_check_end(self, trainer: Trainer, pl_module: LightningModule): + hparams = cast(BaseConfig, pl_module.hparams) + if not self._debug: + log.critical("Sanity check routine complete, disabling debug flag.") + hparams.debug = self._debug + + +def _slurm_session_info(): + try: + from submitit import JobEnvironment + + job = JobEnvironment() + if not job.activated(): + return {} + + return { + "hostname": job.hostname, + "hostnames": job.hostnames, + "job_id": job.job_id, + "raw_job_id": job.raw_job_id, + "array_job_id": job.array_job_id, + "array_task_id": job.array_task_id, + "num_tasks": job.num_tasks, + "num_nodes": job.num_nodes, + "node": job.node, + "global_rank": job.global_rank, + "local_rank": job.local_rank, + } + except (ImportError, RuntimeError): + return {} + + +def _cls_info(cls: type): + name = cls.__name__ + module = cls.__module__ + full_name = f"{cls.__module__}.{cls.__qualname__}" + + file_path = inspect.getfile(cls) + source_file_path = inspect.getsourcefile(cls) + + return { + "name": name, + "module": module, + "full_name": full_name, + "file_path": file_path, + "source_file_path": source_file_path, + } + + +class LightningModuleBase( + ProfilerMixin, + LogDirMixin, + WandbWrapperMixin, + OptimizerModuleMixin, + RLPSanityCheckModuleMixin, + LogEpochMixin, + LoggerModuleMixin, + LRMonitorMixin, + FiniteChecksModuleMixin, + SharedParametersModuleMixin, + ParameterHookModuleMixin, + DistributedMixin, + CallbackModuleMixin, + Base[THparams], + LightningModule, + ABC, + Generic[THparams], +): + hparams: THparams + hparams_initial: THparams + + @classmethod + @abstractmethod + def config_cls(cls) -> type[THparams]: ... + + @classmethod + def _update_environment(cls, hparams: THparams): + hparams.environment.cwd = os.getcwd() + hparams.environment.python_executable = sys.executable + hparams.environment.python_path = sys.path + hparams.environment.python_version = sys.version + hparams.environment.config = _cls_info(cls.config_cls()) + hparams.environment.model = _cls_info(cls) + hparams.environment.slurm = _slurm_session_info() + hparams.environment.log_dir = str( + hparams.trainer.default_root_dir + or LLTrainer.ll_default_root_dir(hparams).absolute() + ) + hparams.environment.seed = ( + int(seed_str) if (seed_str := os.environ.get("PL_GLOBAL_SEED")) else None + ) + hparams.environment.seed_workers = ( + bool(int(seed_everything)) + if (seed_everything := os.environ.get("PL_SEED_WORKERS")) + else None + ) + hparams.environment.sweep_id = os.environ.get("LL_WANDB_SWEEP_ID") + hparams.environment.sweep_config = ( + json.loads(config_json) + if (config_json := os.environ.get("LL_WANDB_SWEEP_CONFIG")) is not None + else None + ) + + @override + def __init__(self, hparams: THparams): + if isinstance(hparams, dict): + hparams = self.config_cls().from_dict(hparams) + self._update_environment(hparams) + + super().__init__(hparams) + + self.save_hyperparameters(hparams) + + self.register_callback(lambda: DebugFlagCallback()) + + actsave.wrap_lightning_module(self) + + if self.config.trainer.log_batch_info_on_error: + log_batch_info.wrap_lightning_module(self) + + if self.config.trainer.supports_skip_batch_exception: + skip_batch.wrap_lightning_module(self) + + def zero_loss(self): + """ + Returns a loss tensor with the value 0. + It multiples each weight by 0 and returns the sum, so we don't run into issues with ununsed parameters in DDP. + """ + loss = sum((0.0 * v).sum() for v in self.parameters()) + loss = cast(torch.Tensor, loss) + return loss + + def skip_batch_training_step(self, *args: Any, **kwargs: Any): + """ + This function gets called when a `SkipBatch` exception is raised during any point in training. + If `automatic_optimization` is enabled, it should return a loss tensor that will be used for the backward pass. By default, it returns a zero loss tensor. + If `automatic_optimization` is disabled, this function needs to be implemented and should handle the backward pass itself. + """ + if not self.automatic_optimization: + raise NotImplementedError( + "To use `SkipBatch` with manual optimization, you must implement `skip_batch_training_step`." + ) + + loss = self.zero_loss() + return loss + + @property + def datamodule(self): + datamodule = getattr(self.trainer, "datamodule", None) + if datamodule is None: + return None + + if not isinstance(datamodule, LightningDataModuleBase): + raise TypeError( + f"datamodule must be a LightningDataModuleBase: {type(datamodule)}" + ) + + datamodule = cast(LightningDataModuleBase[THparams], datamodule) + return datamodule + + # @abstractmethod + # @override + # def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + # ... + + # @abstractmethod + # @override + # def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + # ... + + +class LightningDataModuleBase( + LogDirMixin, + CallbackRegistrarModuleMixin, + Base[THparams], + LightningDataModule, + ABC, + Generic[THparams], +): + hparams: THparams + hparams_initial: THparams + + @classmethod + def _update_environment(cls, hparams: THparams): + hparams.environment.data = _cls_info(cls) + + @override + def __init__(self, hparams: THparams): + self._update_environment(hparams) + super().__init__(hparams) + + self.save_hyperparameters(hparams) + + @property + def lightning_module(self): + if not self.trainer: + raise ValueError("Trainer has not been set.") + + module = self.trainer.lightning_module + if not isinstance(module, LightningModuleBase): + raise ValueError( + f"Trainer's lightning_module is not a LightningModuleBase: {type(module)}" + ) + + module = cast(LightningModuleBase[THparams], module) + return module + + @property + def device(self): + return self.lightning_module.device diff --git a/src/jmp/lightning/model/config.py b/src/jmp/lightning/model/config.py new file mode 100644 index 0000000..28e333a --- /dev/null +++ b/src/jmp/lightning/model/config.py @@ -0,0 +1,736 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +import string +import time +import warnings +from abc import ABC, abstractmethod +from datetime import timedelta +from logging import getLogger +from pathlib import Path +from typing import ( + Annotated, + Any, + ClassVar, + Literal, + Protocol, + Self, + TypeAlias, + runtime_checkable, +) + +import numpy as np +from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment +from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT +from lightning.pytorch.plugins.layer_sync import LayerSync +from lightning.pytorch.plugins.precision.precision import Precision +from lightning.pytorch.profilers import Profiler +from typing_extensions import TypeVar, deprecated, override + +from ..config import Field, TypedConfig + +logger = getLogger(__name__) + + +class IdSeedWarning(Warning): + pass + + +class BaseProfilerConfig(TypedConfig, ABC): + dirpath: str | Path | None = None + """ + Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the + ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`) + will be used. + """ + filename: str | None = None + """ + If present, filename where the profiler results will be saved instead of printing to stdout. + The ``.txt`` extension will be used automatically. + """ + + @abstractmethod + def construct_profiler(self) -> Profiler: ... + + +class SimpleProfilerConfig(BaseProfilerConfig): + kind: Literal["simple"] = "simple" + + extended: bool = True + """ + If ``True``, adds extra columns representing number of calls and percentage of + total time spent onrespective action. + """ + + @override + def construct_profiler(self): + from lightning.pytorch.profilers.simple import SimpleProfiler + + return SimpleProfiler( + extended=self.extended, + dirpath=self.dirpath, + filename=self.filename, + ) + + +class AdvancedProfilerConfig(BaseProfilerConfig): + kind: Literal["advanced"] = "advanced" + + line_count_restriction: float = 1.0 + """ + This can be used to limit the number of functions + reported for each action. either an integer (to select a count of lines), + or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) + """ + + @override + def construct_profiler(self): + from lightning.pytorch.profilers.advanced import AdvancedProfiler + + return AdvancedProfiler( + line_count_restriction=self.line_count_restriction, + dirpath=self.dirpath, + filename=self.filename, + ) + + +class PyTorchProfilerConfig(BaseProfilerConfig): + kind: Literal["pytorch"] = "pytorch" + + group_by_input_shapes: bool = False + """Include operator input shapes and group calls by shape.""" + + emit_nvtx: bool = False + """ + Context manager that makes every autograd operation emit an NVTX range + Run:: + + nvprof --profile-from-start off -o trace_name.prof -- + + To visualize, you can either use:: + + nvvp trace_name.prof + torch.autograd.profiler.load_nvprof(path) + """ + + export_to_chrome: bool = True + """ + Whether to export the sequence of profiled operators for Chrome. + It will generate a ``.json`` file which can be read by Chrome. + """ + + row_limit: int = 20 + """ + Limit the number of rows in a table, ``-1`` is a special value that + removes the limit completely. + """ + + sort_by_key: str | None = None + """ + Attribute used to sort entries. By default + they are printed in the same order as they were registered. + Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, + ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, + ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. + """ + + record_module_names: bool = True + """Whether to add module names while recording autograd operation.""" + + table_kwargs: dict[str, Any] | None = None + """Dictionary with keyword arguments for the summary table.""" + + additional_profiler_kwargs: dict[str, Any] = {} + """Keyword arguments for the PyTorch profiler. This depends on your PyTorch version""" + + @override + def construct_profiler(self): + from lightning.pytorch.profilers.pytorch import PyTorchProfiler + + return PyTorchProfiler( + group_by_input_shapes=self.group_by_input_shapes, + emit_nvtx=self.emit_nvtx, + export_to_chrome=self.export_to_chrome, + row_limit=self.row_limit, + sort_by_key=self.sort_by_key, + record_module_names=self.record_module_names, + table_kwargs=self.table_kwargs, + dirpath=self.dirpath, + filename=self.filename, + **self.additional_profiler_kwargs, + ) + + +ProfilerConfig: TypeAlias = Annotated[ + SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig, + Field(discriminator="kind"), +] + + +class EnvironmentConfig(TypedConfig): + cwd: str | None = None + + python_executable: str | None = None + python_path: list[str] | None = None + python_version: str | None = None + + config: dict[str, Any] | None = None + model: dict[str, Any] | None = None + data: dict[str, Any] | None = None + + slurm: dict[str, Any] | None = None + + log_dir: str | None = None + + seed: int | None = None + seed_workers: bool | None = None + + sweep_id: str | None = None + sweep_config: dict[str, Any] | None = None + + +class WandbWatchConfig(TypedConfig): + enabled: bool = True + """Enable watching the model for wandb.""" + + log: str | None = None + log_graph: bool = True + log_freq: int = 100 + + +class WandbLoggingConfig(TypedConfig): + enabled: bool = True + """Enable logging to wandb.""" + + log_model: bool | str = False + """ + Whether to log the model checkpoints to wandb. + Valid values are: + - False: Do not log the model checkpoints. + - True: Log the latest model checkpoint. + - "all": Log all model checkpoints. + """ + + watch: WandbWatchConfig = WandbWatchConfig() + """WandB model watch configuration. Used to log model architecture, gradients, and parameters.""" + + +class CSVLoggingConfig(TypedConfig): + enabled: bool = True + """Enable logging to CSV files.""" + + +class TensorboardLoggingConfig(TypedConfig): + enabled: bool = False + """Enable logging to tensorboard.""" + + +class LoggingConfig(TypedConfig): + enabled: bool = True + """Enable logging.""" + + log_lr: bool | Literal["step", "epoch"] = True + """If enabled, will register a `LearningRateMonitor` callback to log the learning rate to the logger.""" + log_epoch: bool = True + """If enabled, will log the fractional epoch number to the logger.""" + + wandb: WandbLoggingConfig = WandbLoggingConfig() + """WandB configuration""" + + csv: CSVLoggingConfig = CSVLoggingConfig() + """CSV configuration""" + + tensorboard: TensorboardLoggingConfig = TensorboardLoggingConfig() + """Tensorboard configuration""" + + +class GradientClippingConfig(TypedConfig): + enabled: bool = True + """Enable gradient clipping.""" + value: int | float + """Value to use for gradient clipping.""" + algorithm: Literal["value", "norm"] = "norm" + """Norm type to use for gradient clipping.""" + + +class GradientSkippingConfig(TypedConfig): + enabled: bool = True + """Enable gradient skipping.""" + norm_type: str | float = 2.0 + """Norm type to use for gradient skipping.""" + threshold: float = float("inf") + """Threshold to use for gradient skipping.""" + start_after_n_steps: int | None = 100 + """Number of steps to wait before starting gradient skipping.""" + + +class OptimizerConfig(TypedConfig): + grad_finite_checks: bool = False + """If enabled, will check that the gradients are finite after each backward pass.""" + grad_none_checks: bool = False + """If enabled, will check that the gradients are not None after each backward pass.""" + + log_grad_norm: bool | str | float = False + """If enabled, will log the gradient norm (averaged across all model parameters) to the logger.""" + log_grad_norm_per_param: bool | str | float = False + """If enabled, will log the gradient norm for each model parameter to the logger.""" + + log_param_norm: bool | str | float = False + """If enabled, will log the parameter norm (averaged across all model parameters) to the logger.""" + log_param_norm_per_param: bool | str | float = False + """If enabled, will log the parameter norm for each model parameter to the logger.""" + + gradient_clipping: GradientClippingConfig | None = None + """Gradient clipping configuration, or None to disable gradient clipping.""" + + gradient_skipping: GradientSkippingConfig | None = None + """Gradient skipping configuration, or None to disable gradient skipping.""" + + +class PythonLogging(TypedConfig): + log_level: ( + Literal["CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG"] | None + ) = None + """Log level to use for the Python logger (or None to use the default).""" + + rich: bool = True + """If enabled, will use the rich library to format the Python logger output.""" + rich_tracebacks: bool = True + """If enabled, will use the rich library to format the Python logger tracebacks.""" + + lovely_tensors: bool = True + """If enabled, will use the lovely-tensors library to format PyTorch tensors.""" + lovely_numpy: bool = False + """If enabled, will use the lovely-numpy library to format numpy arrays. False by default as it causes some issues with other libaries.""" + + +TPlugin = TypeVar( + "TPlugin", + Precision, + ClusterEnvironment, + CheckpointIO, + LayerSync, + infer_variance=True, +) + + +@runtime_checkable +class PluginConfigProtocol(Protocol[TPlugin]): + def construct_plugin(self) -> TPlugin: ... + + +class TrainerConfig(TypedConfig): + python_logging: PythonLogging = PythonLogging() + """Python logging configuration options.""" + + logging: LoggingConfig = LoggingConfig() + """Logging (e.g., WandB logging) configuration options.""" + + optimizer: OptimizerConfig = OptimizerConfig() + """Optimizer configuration options.""" + + seed: int | None = 0 + """Seed for the random number generator. If None, will use a random seed.""" + seed_workers: bool = False + """Whether to seed the workers of the dataloader.""" + default_ckpt_path: str | None = None + """Default checkpoint path to use when loading a checkpoint. "last" will load the last checkpoint. "hpc" will load the SLURM pre-empted checkpoint.""" + + auto_wrap_trainer: bool = True + """If enabled, will automatically wrap the `run` function with a `Trainer.context()` context manager. Should be `True` most of the time.""" + auto_set_default_root_dir: bool = True + """If enabled, will automatically set the default root dir to [cwd/lightning_logs//]. Should be `True` most of the time.""" + auto_set_loggers: bool = True + """If enabled, will automatically set the loggers to [WandbLogger, CSVLogger, TensorboardLogger] as defined in `config.logging`. Should be `True` most of the time.""" + checkpoint_last_by_default: bool = True + """If enabled, will update the trainer to save the last checkpoint by default.""" + on_exception_checkpoint: bool = True + """If enabled, will checkpoint the model when an exception is thrown during training.""" + auto_add_trainer_finalizer: bool = True + """If enabled, will automatically finalize the trainer (e.g., call `wandb.finish()`) when the run ends. Should be `True` most of the time.""" + enable_logger_validation: bool = True + """If enabled, will validate loggers. This makes sure that the logger's log_dirs are correct given the current config id. Should be `True` most of the time.""" + + supports_skip_batch_exception: bool = True + """If enabled, the model supports skipping an entire batch by throwing a `SkipBatch` exception.""" + supports_shared_parameters: bool = True + """If enabled, the model supports scaling the gradients of shared parameters that are registered using `LightningModuleBase.register_shared_parameters(...)`""" + supports_parameter_hooks: bool = True + """If enabled, the model supports registering parameter hooks using `LightningModuleBase.register_parameter_hook(...)`""" + log_batch_info_on_error: bool = False + """If enabled, will log the batch info (e.g. batch index, batch object, etc.) when an exception is thrown during training.""" + reduce_lr_on_plateau_sanity_checks: Literal["disable", "error", "warn"] = "error" + """ + Valid values are: "disable", "warn", "error" + If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used: + - If the `interval` is step, it makes sure that validation is called every `frequency` steps. + - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs. + """ + + additional_trainer_kwargs: dict[str, Any] = {} + """Additional keyword arguments to pass to the Lightning `pl.Trainer` constructor.""" + additional_env_vars: dict[str, str] = {} + """Additional environment variables to set when running the trainer.""" + set_nccl_optimal_params: bool = False + """If enabled, will set the NCCL optimal parameters when running on multiple GPUs + nodes.""" + + set_float32_matmul_precision: Literal["medium", "high", "highest"] | None = None + """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs.""" + + accelerator: Literal["cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto"] = "auto" + """ + Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") + as well as custom accelerator instances. + """ + strategy: str = "auto" + """ + Supports different training strategies with aliases as well custom strategies. + Default: ``"auto"``. + """ + devices: list[int] | str | int = "auto" + """ + The devices to use. Can be set to a positive number (int or str), a sequence of device indices + (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for + automatic selection based on the chosen accelerator. Default: ``"auto"``. + """ + num_nodes: Literal["auto"] | int = "auto" + """ + Number of GPU nodes for distributed training, + or ``"auto"`` to automatically detect the number of nodes. + """ + precision: _PRECISION_INPUT = "32-true" + """ + Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), + 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). + Can be used on CPU, GPU, TPUs, HPUs or IPUs. + Default: ``'32-true'``. + """ + logger: bool | None = None + """ + Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses + the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``. + ``False`` will disable logging. If multiple loggers are provided, local files + (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger. + Default: ``True``. + """ + fast_dev_run: int | bool = False + """ + Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) + of train, val and test to find any bugs (ie: a sort of unit test). + Default: ``False``. + """ + max_epochs: int | None = None + """ + Stop training once this number of epochs is reached. Disabled by default (None). + If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``. + To enable infinite training, set ``max_epochs = -1``. + """ + min_epochs: int | None = None + """Force training for at least these many epochs. Disabled by default (None).""" + max_steps: int = -1 + """ + Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1`` + and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set + ``max_epochs`` to ``-1``. + """ + min_steps: int | None = None + """Force training for at least these number of steps. Disabled by default (``None``).""" + max_time: str | timedelta | dict[str, Any] | None = None + """ + Stop training after this amount of time has passed. Disabled by default (``None``). + The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes seconds), as a + :class:`datetime.timedelta`, or a dictionary with keys that will be passed to + :class:`datetime.timedelta`. + """ + limit_train_batches: int | float | None = None + """ + How much of training dataset to check (float = fraction, int = num_batches). + Default: ``1.0``. + """ + limit_val_batches: int | float | None = None + """ + How much of validation dataset to check (float = fraction, int = num_batches). + Default: ``1.0``. + """ + limit_test_batches: int | float | None = None + """ + How much of test dataset to check (float = fraction, int = num_batches). + Default: ``1.0``. + """ + limit_predict_batches: int | float | None = None + """ + How much of prediction dataset to check (float = fraction, int = num_batches). + Default: ``1.0``. + """ + overfit_batches: int | float = 0.0 + """ + Overfit a fraction of training/validation data (float) or a set number of batches (int). + ``0.0`` means no overfitting. Default: ``0.0``. + """ + val_check_interval: int | float | None = None + """ + How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check + after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training + batches. An ``int`` value can only be higher than the number of training batches when + ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches + across epochs or during iteration-based training. + Default: ``1.0``. + """ + check_val_every_n_epoch: int | None = 1 + """ + Perform a validation loop every after every `N` training epochs. If ``None``, + validation will be done solely based on the number of training batches, requiring ``val_check_interval`` + to be an integer value. + Default: ``1``. + """ + num_sanity_val_steps: int | None = None + """ + Sanity check runs n validation batches before starting the training routine. + Set it to `-1` to run all batches in all validation dataloaders. + Default: ``2``. + """ + log_every_n_steps: int = 50 + """ + How often to log within steps. + Default: ``50``. + """ + enable_checkpointing: bool | None = None + """ + If ``True``, enable checkpointing. + It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`. + Default: ``True``. + """ + enable_progress_bar: bool | None = None + """ + Whether to enable to progress bar by default. + Default: ``True``. + """ + enable_model_summary: bool | None = None + """ + Whether to enable model summarization by default. + Default: ``True``. + """ + accumulate_grad_batches: int = 1 + """ + Accumulates gradients over k batches before stepping the optimizer. + ``1`` means no gradient accumulation (i.e., performs a step after each batch). + Default: ``1``. + """ + deterministic: bool | str | None = None + """ + If ``True``, sets whether PyTorch operations must use deterministic algorithms. + Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations + that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``. + """ + benchmark: bool | None = None + """ + The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to. + The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used + (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic` + is set to ``True``, this will default to ``False``. Override to manually set a different value. + Default: ``None``. + """ + inference_mode: bool = True + """ + Whether to use :func:`torch.inference_mode` (if `True`) or :func:`torch.no_grad` (if `False`) during + evaluation (``validate``/``test``/``predict``). + """ + use_distributed_sampler: bool = True + """ + Whether to wrap the DataLoader's sampler with + :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for + strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and + ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass + ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed + sampler was already added, Lightning will not replace the existing one. For iterable-style datasets, + we don't do this automatically. + """ + profiler: str | ProfilerConfig | None = None + """ + To profile individual steps during training and assist in identifying bottlenecks. + Default: ``None``. + """ + detect_anomaly: bool = False + """ + Enable anomaly detection for the autograd engine. + Default: ``False``. + """ + barebones: bool = False + """ + Whether to run in "barebones mode", where all features that may impact raw speed are + disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training + runs. The following features are deactivated: + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`, + :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`, + :meth:`~lightning.pytorch.core.LightningModule.log`, + :meth:`~lightning.pytorch.core.LightningModule.log_dict`. + """ + plugins: list[PluginConfigProtocol] | None = None + """ + Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins. + Default: ``None``. + """ + sync_batchnorm: bool = False + """ + Synchronize batch norm layers between process groups/whole world. + Default: ``False``. + """ + reload_dataloaders_every_n_epochs: int = 0 + """ + Set to a positive integer to reload dataloaders every n epochs. + Default: ``0``. + """ + default_root_dir: str | Path | None = None + """ + Default path for logs and weights when no logger/ckpt_callback passed. + Default: ``os.getcwd()``. + Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' + """ + + # region Deprecated fields + @property + @deprecated("Please use trainer.optimizer.gradient_clipping instead.") + def automatic_gradient_clip(self): + if (config := self.optimizer.gradient_clipping) is None or not config.enabled: + return False + return True + + @automatic_gradient_clip.setter + @deprecated("Please use trainer.optimizer.gradient_clipping instead.") + def automatic_gradient_clip(self, value: bool): + if self.optimizer.gradient_clipping is None: + self.optimizer.gradient_clipping = GradientClippingConfig( + enabled=False, value=1.0, algorithm="norm" + ) + self.optimizer.gradient_clipping.enabled = value + + @property + @deprecated("Please use trainer.optimizer.gradient_clipping instead.") + def gradient_clip_algorithm(self): + if (config := self.optimizer.gradient_clipping) is None or not config.enabled: + return "norm" + return config.algorithm + + @gradient_clip_algorithm.setter + @deprecated("Please use trainer.optimizer.gradient_clipping instead.") + def gradient_clip_algorithm(self, value: Literal["value", "norm"]): + if self.optimizer.gradient_clipping is None: + self.optimizer.gradient_clipping = GradientClippingConfig( + enabled=False, value=1.0, algorithm=value + ) + self.optimizer.gradient_clipping.algorithm = value + + @property + @deprecated("Please use trainer.optimizer.gradient_clipping instead.") + def gradient_clip_val(self): + if (config := self.optimizer.gradient_clipping) is None or not config.enabled: + return None + return config.value + + @gradient_clip_val.setter + @deprecated("Please use trainer.optimizer.gradient_clipping instead.") + def gradient_clip_val(self, value: int | float | None): + if value is None: + self.optimizer.gradient_clipping = None + return + + if self.optimizer.gradient_clipping is None: + self.optimizer.gradient_clipping = GradientClippingConfig( + enabled=False, value=value, algorithm="norm" + ) + self.optimizer.gradient_clipping.enabled = True + self.optimizer.gradient_clipping.value = value + + # endregion + + +class RunnerOutputSaveConfig(TypedConfig): + enabled: bool = True + """Enable saving the runner stdout and stderr to a file.""" + dirpath: str | Path | None = None + """Directory path for the output file. If None, will use the current working directory/ll_runner_logs/{id}""" + + +class RunnerConfig(TypedConfig): + auto_call_trainer_init_from_runner: bool = True + """If enabled, will automatically call the Trainer.runner_init() function from the Runner. Should be `True` most of the time.""" + save_output: RunnerOutputSaveConfig | None = None + """Output saving configuration options, or ``None`` to disable output saving.""" + + +class BaseConfig(TypedConfig): + id: str = Field(default_factory=lambda: BaseConfig.generate_id()) + """ID of the run.""" + name: str | None = None + """Run name.""" + project: str | None = None + """Project name.""" + tags: list[str] = [] + """Tags for the run.""" + notes: list[str] = [] + """Human readable notes for the run.""" + + debug: bool = False + """Whether to run in debug mode. This will enable debug logging and enable debug code paths.""" + environment: EnvironmentConfig = EnvironmentConfig() + """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script.""" + trainer: TrainerConfig = TrainerConfig() + """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information.""" + runner: RunnerConfig = RunnerConfig() + """`jmp.lightning.Runner` configuration options.""" + + """Additional metadata for this run. This can be used to store arbitrary data that is not part of the config schema.""" + meta: dict[str, Any] = {} + + def clone(self, with_new_id: bool = True) -> Self: + c = copy.deepcopy(self) + if with_new_id: + c.id = BaseConfig.generate_id() + return c + + # region Seeding + + _rng: ClassVar[np.random.Generator | None] = None + + @staticmethod + def generate_id( + *, + length: int = 8, + ignore_rng: bool = False, + ) -> str: + rng = BaseConfig._rng if not ignore_rng else np.random.default_rng() + if rng is None: + warnings.warn( + "BaseConfig._rng is None. The generated IDs will not be reproducible. " + + "To fix this, call BaseConfig.set_seed(...) before generating any IDs.", + category=IdSeedWarning, + ) + rng = np.random.default_rng() + + alphabet = list(string.ascii_lowercase + string.digits) + + id = "".join(rng.choice(alphabet) for _ in range(length)) + return id + + @staticmethod + def set_seed(seed: int | None = None) -> None: + if seed is None: + seed = int(time.time() * 1000) + logger.critical(f"Seeding BaseConfig with seed {seed}") + BaseConfig._rng = np.random.default_rng(seed) + + # endregion diff --git a/src/jmp/lightning/model/modules/callback.py b/src/jmp/lightning/model/modules/callback.py new file mode 100644 index 0000000..ab84a42 --- /dev/null +++ b/src/jmp/lightning/model/modules/callback.py @@ -0,0 +1,163 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections import abc +from logging import getLogger +from typing import Any, Callable, Iterable, cast, final + +from lightning.pytorch import Callback, LightningModule +from lightning.pytorch.callbacks import LambdaCallback +from typing_extensions import override + +from ...util.typing_utils import mixin_base_type + +log = getLogger(__name__) + +CallbackFn = Callable[[], Callback | Iterable[Callback] | None] + + +class CallbackRegistrarModuleMixin: + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._ll_callbacks: list[CallbackFn] = [] + + def register_callback( + self, + callback: Callback | Iterable[Callback] | CallbackFn | None = None, + *, + setup: Callable | None = None, + teardown: Callable | None = None, + on_fit_start: Callable | None = None, + on_fit_end: Callable | None = None, + on_sanity_check_start: Callable | None = None, + on_sanity_check_end: Callable | None = None, + on_train_batch_start: Callable | None = None, + on_train_batch_end: Callable | None = None, + on_train_epoch_start: Callable | None = None, + on_train_epoch_end: Callable | None = None, + on_validation_epoch_start: Callable | None = None, + on_validation_epoch_end: Callable | None = None, + on_test_epoch_start: Callable | None = None, + on_test_epoch_end: Callable | None = None, + on_validation_batch_start: Callable | None = None, + on_validation_batch_end: Callable | None = None, + on_test_batch_start: Callable | None = None, + on_test_batch_end: Callable | None = None, + on_train_start: Callable | None = None, + on_train_end: Callable | None = None, + on_validation_start: Callable | None = None, + on_validation_end: Callable | None = None, + on_test_start: Callable | None = None, + on_test_end: Callable | None = None, + on_exception: Callable | None = None, + on_save_checkpoint: Callable | None = None, + on_load_checkpoint: Callable | None = None, + on_before_backward: Callable | None = None, + on_after_backward: Callable | None = None, + on_before_optimizer_step: Callable | None = None, + on_before_zero_grad: Callable | None = None, + on_predict_start: Callable | None = None, + on_predict_end: Callable | None = None, + on_predict_batch_start: Callable | None = None, + on_predict_batch_end: Callable | None = None, + on_predict_epoch_start: Callable | None = None, + on_predict_epoch_end: Callable | None = None, + ): + if callback is None: + callback = LambdaCallback( + setup=setup, + teardown=teardown, + on_fit_start=on_fit_start, + on_fit_end=on_fit_end, + on_sanity_check_start=on_sanity_check_start, + on_sanity_check_end=on_sanity_check_end, + on_train_batch_start=on_train_batch_start, + on_train_batch_end=on_train_batch_end, + on_train_epoch_start=on_train_epoch_start, + on_train_epoch_end=on_train_epoch_end, + on_validation_epoch_start=on_validation_epoch_start, + on_validation_epoch_end=on_validation_epoch_end, + on_test_epoch_start=on_test_epoch_start, + on_test_epoch_end=on_test_epoch_end, + on_validation_batch_start=on_validation_batch_start, + on_validation_batch_end=on_validation_batch_end, + on_test_batch_start=on_test_batch_start, + on_test_batch_end=on_test_batch_end, + on_train_start=on_train_start, + on_train_end=on_train_end, + on_validation_start=on_validation_start, + on_validation_end=on_validation_end, + on_test_start=on_test_start, + on_test_end=on_test_end, + on_exception=on_exception, + on_save_checkpoint=on_save_checkpoint, + on_load_checkpoint=on_load_checkpoint, + on_before_backward=on_before_backward, + on_after_backward=on_after_backward, + on_before_optimizer_step=on_before_optimizer_step, + on_before_zero_grad=on_before_zero_grad, + on_predict_start=on_predict_start, + on_predict_end=on_predict_end, + on_predict_batch_start=on_predict_batch_start, + on_predict_batch_end=on_predict_batch_end, + on_predict_epoch_start=on_predict_epoch_start, + on_predict_epoch_end=on_predict_epoch_end, + ) + + if not callable(callback): + callback_ = cast(CallbackFn, lambda: callback) + else: + callback_ = callback + + self._ll_callbacks.append(callback_) + + +class CallbackModuleMixin( + CallbackRegistrarModuleMixin, mixin_base_type(LightningModule) +): + def _gather_all_callbacks(self): + modules: list[Any] = [] + if isinstance(self, CallbackRegistrarModuleMixin): + modules.append(self) + if ( + datamodule := getattr(self.trainer, "datamodule", None) + ) is not None and isinstance(datamodule, CallbackRegistrarModuleMixin): + modules.append(datamodule) + modules.extend( + module + for module in self.children() + if isinstance(module, CallbackRegistrarModuleMixin) + ) + for module in modules: + yield from module._ll_callbacks + + @final + @override + def configure_callbacks(self): + callbacks = super().configure_callbacks() + if not isinstance(callbacks, abc.Sequence): + callbacks = [callbacks] + + callbacks = list(callbacks) + for callback_fn in self._gather_all_callbacks(): + callback_result = callback_fn() + if callback_result is None: + continue + + if not isinstance(callback_result, abc.Iterable): + callback_result = [callback_result] + + for callback in callback_result: + log.info( + f"Registering {callback.__class__.__qualname__} callback {callback}" + ) + callbacks.append(callback) + + return callbacks diff --git a/src/jmp/lightning/model/modules/debug.py b/src/jmp/lightning/model/modules/debug.py new file mode 100644 index 0000000..d426af7 --- /dev/null +++ b/src/jmp/lightning/model/modules/debug.py @@ -0,0 +1,54 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger + +import torch +import torch.distributed +from varname import argname + +log = getLogger(__name__) + + +class DebugModuleMixin: + @torch.jit.unused + def breakpoint(self, rank_zero_only: bool = True): + if ( + not rank_zero_only + or not torch.distributed.is_initialized() + or torch.distributed.get_rank() == 0 + ): + breakpoint() + + if rank_zero_only and torch.distributed.is_initialized(): + _ = torch.distributed.barrier() + + @torch.jit.unused + def ensure_finite( + self, + tensor: torch.Tensor, + name: str | None = None, + throw: bool = False, + ): + if name is None: + arg_name = argname("tensor", vars_only=False) + + if arg_name is None: + raise ValueError("Could not infer name for `tensor`") + + name = str(arg_name) + + not_finite = ~torch.isfinite(tensor) + if not_finite.any(): + msg = f"Tensor {name} has {not_finite.sum().item()}/{not_finite.numel()} non-finite values." + if throw: + raise RuntimeError(msg) + else: + log.warning(msg) + return False + return True diff --git a/src/jmp/lightning/model/modules/distributed.py b/src/jmp/lightning/model/modules/distributed.py new file mode 100644 index 0000000..6d47866 --- /dev/null +++ b/src/jmp/lightning/model/modules/distributed.py @@ -0,0 +1,78 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Any, Literal, cast + +import torch.distributed +from lightning.fabric.utilities.distributed import ReduceOp +from lightning.pytorch import LightningModule +from typing_extensions import TypeVar + +from ...util.typing_utils import mixin_base_type + +T = TypeVar("T", infer_variance=True) + +ReduceOpStr = Literal[ + "avg", + "mean", + "band", + "bor", + "bxor", + "max", + "min", + "premul_sum", + "product", + "sum", +] +VALID_REDUCE_OPS = ( + "avg", + "mean", + "band", + "bor", + "bxor", + "max", + "min", + "premul_sum", + "product", + "sum", +) + + +class DistributedMixin(mixin_base_type(LightningModule)): + def all_gather_object( + self, + object: T, + group: torch.distributed.ProcessGroup | None = None, + ) -> list[T]: + if ( + not torch.distributed.is_available() + or not torch.distributed.is_initialized() + ): + return [object] + + object_list = [cast(T, None) for _ in range(self.trainer.world_size)] + torch.distributed.all_gather_object(object_list, object, group=group) + return object_list + + def barrier(self, name: str | None = None): + self.trainer.strategy.barrier(name=name) + + def reduce( + self, + tensor: torch.Tensor, + reduce_op: ReduceOp | ReduceOpStr, + group: Any | None = None, + ) -> torch.Tensor: + if isinstance(reduce_op, str): + # validate reduce_op + if reduce_op not in VALID_REDUCE_OPS: + raise ValueError( + f"reduce_op must be one of {VALID_REDUCE_OPS}, got {reduce_op}" + ) + + return self.trainer.strategy.reduce(tensor, group=group, reduce_op=reduce_op) diff --git a/src/jmp/lightning/model/modules/finite_checks.py b/src/jmp/lightning/model/modules/finite_checks.py new file mode 100644 index 0000000..8bc258f --- /dev/null +++ b/src/jmp/lightning/model/modules/finite_checks.py @@ -0,0 +1,41 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import cast + +from typing_extensions import override + +from ...callbacks.bad_gradients import PrintBadGradientsCallback, print_bad_gradients +from ...util.typing_utils import mixin_base_type +from ..config import BaseConfig +from .callback import CallbackModuleMixin + + +class FiniteChecksModuleMixin(mixin_base_type(CallbackModuleMixin)): + def print_bad_gradients(self): + print_bad_gradients(self) + + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _cb(): + nonlocal self + config = cast(BaseConfig, self.hparams) + if ( + not config.trainer.optimizer.grad_finite_checks + and not config.trainer.optimizer.grad_none_checks + ): + return None + + return PrintBadGradientsCallback( + none_grads=config.trainer.optimizer.grad_none_checks, + nonfinite_grads=config.trainer.optimizer.grad_finite_checks, + ) + + self.register_callback(_cb) diff --git a/src/jmp/lightning/model/modules/log_dir.py b/src/jmp/lightning/model/modules/log_dir.py new file mode 100644 index 0000000..576273c --- /dev/null +++ b/src/jmp/lightning/model/modules/log_dir.py @@ -0,0 +1,31 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path + +from lightning.pytorch import LightningDataModule, LightningModule + + +class LogDirMixin: + @property + def log_dir(self): + if not isinstance(self, (LightningModule, LightningDataModule)): + raise TypeError( + "log_dir can only be used on LightningModule or LightningDataModule" + ) + + if (trainer := self.trainer) is None: + raise RuntimeError("trainer is not defined") + + if (logger := trainer.logger) is None: + raise RuntimeError("trainer.logger is not defined") + + if (log_dir := logger.log_dir) is None: + raise RuntimeError("trainer.logger.log_dir is not defined") + + return Path(log_dir) diff --git a/src/jmp/lightning/model/modules/log_epoch.py b/src/jmp/lightning/model/modules/log_epoch.py new file mode 100644 index 0000000..7651e36 --- /dev/null +++ b/src/jmp/lightning/model/modules/log_epoch.py @@ -0,0 +1,57 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Protocol, cast, runtime_checkable + +from lightning.pytorch import LightningModule, Trainer +from typing_extensions import override + +from ...util.typing_utils import mixin_base_type +from ..config import BaseConfig +from .callback import CallbackModuleMixin + + +@runtime_checkable +class _HasEpochProperty(Protocol): + @property + def epoch(self) -> float: ... + + +def _log_epoch_callback(module: LightningModule, trainer: Trainer, *, prefix: str): + if trainer.logger is None: + return + + config = cast(BaseConfig, module.hparams).trainer.logging + if not config.log_epoch: + return + + if not isinstance(module, _HasEpochProperty): + raise TypeError(f"Expected {prefix} to have an epoch property") + + module.log(f"{prefix}epoch", module.epoch, on_step=True, on_epoch=False) + + +class LogEpochMixin(mixin_base_type(CallbackModuleMixin)): + @property + def epoch(self): + return self.global_step / self.trainer.num_training_batches + + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.register_callback( + on_train_batch_start=lambda trainer, + module, + *args, + **kwargs: _log_epoch_callback(module, trainer, prefix="train/"), + on_validation_batch_start=lambda trainer, + module, + *args, + **kwargs: _log_epoch_callback(module, trainer, prefix="val/"), + ) diff --git a/src/jmp/lightning/model/modules/logger.py b/src/jmp/lightning/model/modules/logger.py new file mode 100644 index 0000000..87ee46e --- /dev/null +++ b/src/jmp/lightning/model/modules/logger.py @@ -0,0 +1,135 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable, Generator + +import torchmetrics +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.rank_zero import rank_zero_warn +from lightning.pytorch.utilities.types import _METRIC +from typing_extensions import override + +from ...actsave import ActSave +from ...util.typing_utils import mixin_base_type + + +@dataclass(frozen=True, kw_only=True) +class _LogContext: + prefix: str | None = None + disabled: bool | None = None + kwargs: dict[str, Any] = field(default_factory=dict) + + +class LoggerModuleMixin(mixin_base_type(LightningModule)): + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.__prefix_stack = deque[_LogContext]() + + if TYPE_CHECKING: + + @contextmanager + def log_context( + self, + prefix: str | None = None, + *, + disabled: bool | None = None, + prog_bar: bool | None = None, + logger: bool | None = None, + on_step: bool | None = None, + on_epoch: bool | None = None, + reduce_fx: str | Callable | None = None, + enable_graph: bool | None = None, + sync_dist: bool | None = None, + sync_dist_group: Any | None = None, + add_dataloader_idx: bool | None = None, + batch_size: int | None = None, + rank_zero_only: bool | None = None, + ) -> Generator[None, None, None]: ... + + else: + + @contextmanager + def log_context( + self, prefix: str | None = None, *, disabled: bool | None = None, **kwargs + ) -> Generator[None, None, None]: + self.__prefix_stack.append( + _LogContext( + prefix=prefix, + disabled=disabled, + kwargs=kwargs, + ) + ) + try: + yield + finally: + _ = self.__prefix_stack.pop() + + if TYPE_CHECKING: + + @override + def log( + self, + name: str, + value: _METRIC, + *, + prog_bar: bool = False, + logger: bool | None = None, + on_step: bool | None = None, + on_epoch: bool | None = None, + reduce_fx: str | Callable = "mean", + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_group: Any | None = None, + add_dataloader_idx: bool = True, + batch_size: int | None = None, + metric_attribute: str | None = None, + rank_zero_only: bool = False, + ) -> None: ... + + else: + + @override + def log(self, name: str, value: _METRIC, **kwargs) -> None: + # join all prefixes + prefix = "".join(c.prefix for c in self.__prefix_stack if c.prefix) + name = f"{prefix}{name}" + + # check for disabled context: + # if the topmost non-null context is disabled, then we don't log + for c in reversed(self.__prefix_stack): + if c.disabled is not None: + if c.disabled: + rank_zero_warn( + f"Skipping logging of {name} due to disabled context" + ) + return + else: + break + + fn_kwargs = {} + for c in self.__prefix_stack: + fn_kwargs.update(c.kwargs) + fn_kwargs.update(kwargs) + + self.__logger_actsave(name, value) + + return super().log(name, value, **fn_kwargs) + + def __logger_actsave(self, name: str, value: _METRIC) -> None: + ActSave.save( + { + f"logger::{name}": lambda: value.compute() + if isinstance(value, torchmetrics.Metric) + else value + } + ) diff --git a/src/jmp/lightning/model/modules/lr_monitor.py b/src/jmp/lightning/model/modules/lr_monitor.py new file mode 100644 index 0000000..e4db58c --- /dev/null +++ b/src/jmp/lightning/model/modules/lr_monitor.py @@ -0,0 +1,46 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from typing import cast + +from lightning.pytorch.callbacks import LearningRateMonitor +from typing_extensions import override + +from ...util.typing_utils import mixin_base_type +from ..config import BaseConfig +from .callback import CallbackModuleMixin + +log = getLogger(__name__) + + +class LRMonitorMixin(mixin_base_type(CallbackModuleMixin)): + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _lr_monitor_callback(): + nonlocal self + + config = cast(BaseConfig, self.hparams).trainer.logging + if not config.log_lr: + return None + + if self.logger is None: + log.warning( + "Skipping LR logging because no logger is configured. " + "Add a logger to your trainer to log learning rates." + ) + return None + + logging_interval: str | None = None + if isinstance(config.log_lr, str): + logging_interval = config.log_lr + return LearningRateMonitor(logging_interval=logging_interval) + + self.register_callback(_lr_monitor_callback) diff --git a/src/jmp/lightning/model/modules/optimizer.py b/src/jmp/lightning/model/modules/optimizer.py new file mode 100644 index 0000000..4ea7832 --- /dev/null +++ b/src/jmp/lightning/model/modules/optimizer.py @@ -0,0 +1,222 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from typing import Any, Literal, cast + +import torch +import torch.nn as nn +import torchmetrics +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import LambdaCallback +from torch.optim import Optimizer +from typing_extensions import override + +from ...util.typing_utils import mixin_base_type +from ..config import BaseConfig +from .callback import CallbackModuleMixin + +log = getLogger(__name__) + + +def grad_norm( + module: nn.Module, + norm_type: float | int | str, + group_separator: str = "/", + grad: bool = True, +) -> dict[str, float]: + """Compute each parameter's gradient's norm and their overall norm. + + The overall norm is computed over all gradients together, as if they + were concatenated into a single vector. + + Args: + module: :class:`torch.nn.Module` to inspect. + norm_type: The type of the used p-norm, cast to float if necessary. + Can be ``'inf'`` for infinity norm. + group_separator: The separator string used by the logger to group + the gradients norms in their own subfolder instead of the logs one. + + Return: + norms: The dictionary of p-norms of each parameter's gradient and + a special entry for the total p-norm of the gradients viewed + as a single vector. + """ + norm_type = float(norm_type) + if norm_type <= 0: + raise ValueError( + f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}" + ) + + if grad: + norms = { + f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type) + for name, p in module.named_parameters() + if p.grad is not None + } + if norms: + total_norm = torch.tensor(list(norms.values())).norm(norm_type) + norms[f"grad_{norm_type}_norm_total"] = total_norm + else: + norms = { + f"param_{norm_type}_norm{group_separator}{name}": p.data.norm(norm_type) + for name, p in module.named_parameters() + if p.grad is not None + } + if norms: + total_norm = torch.tensor(list(norms.values())).norm(norm_type) + norms[f"param_{norm_type}_norm_total"] = total_norm + + return norms + + +def _to_norm_type(log_grad_norm_per_param: float | str | Literal[True]): + norm_type = 2.0 + if log_grad_norm_per_param is not True: + norm_type = log_grad_norm_per_param + return norm_type + + +def _skipped_steps_on_before_optimizer_step( + trainer: Trainer, + pl_module: LightningModule, + optimizer: Optimizer, +) -> None: + if not isinstance(pl_module, OptimizerModuleMixin): + raise TypeError(f"Expected OptimizerModuleMixin, got {type(pl_module)}") + + if ( + config := cast( + BaseConfig, pl_module.hparams + ).trainer.optimizer.gradient_skipping + ) is None or not config.enabled: + return + + # Skip the step if the global step is less than the start_after_n_steps + # This is because we want to let AMP adjust the loss scale before we start + if ( + config.start_after_n_steps is not None + and pl_module.global_step < config.start_after_n_steps + ): + return + + norm = pl_module.compute_parameter_norm(optimizer, config.norm_type) + # If the norm is NaN/Inf, we don't want to skip the step + # beacuse AMP checks for NaN/Inf grads to adjust the loss scale. + if torch.isfinite(norm).all() and (norm > config.threshold).any(): + optimizer.zero_grad() + log.warning( + f"Skipping step at global step {pl_module.global_step} with norm {norm:.2f} > {config.threshold:.2f}" + ) + pl_module.grad_skipped_steps(1) + else: + pl_module.grad_skipped_steps(0) + + pl_module.log( + "train/grad_skipped_steps", + pl_module.grad_skipped_steps, + on_step=True, + on_epoch=False, + ) + pl_module._perform_norm_logging(optimizer, prefix="train/post_skip_") + + +class OptimizerModuleMixin(mixin_base_type(CallbackModuleMixin)): + grad_skipped_steps: torchmetrics.SumMetric + + @override + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + + def _grad_skip_callback(): + nonlocal self + + if ( + config := cast( + BaseConfig, self.hparams + ).trainer.optimizer.gradient_skipping + ) is not None and config.enabled: + self.grad_skipped_steps = torchmetrics.SumMetric() + + return LambdaCallback( + on_before_optimizer_step=_skipped_steps_on_before_optimizer_step + ) + + self.register_callback(_grad_skip_callback) + + def compute_parameter_norm( + self, + optimizer: Optimizer | None = None, + p: float | str = 2.0, + grad: bool = True, + ) -> torch.Tensor: + if optimizer is not None: + tensors = [ + cast(torch.Tensor, p.grad if grad else p) + for group in optimizer.param_groups + for p in group["params"] + if p.grad is not None + ] + else: + tensors = [ + p.grad if grad else p for p in self.parameters() if p.grad is not None + ] + + if not tensors: + return torch.tensor(0.0, device=self.device) + + return torch.norm(torch.stack([torch.norm(g, p=p) for g in tensors]), p=p) + + def _perform_norm_logging(self, optimizer: Optimizer, prefix: str): + config = cast(BaseConfig, self.hparams) + + # Gradient norm logging + if log_grad_norm := config.trainer.optimizer.log_grad_norm: + norm = self.compute_parameter_norm( + optimizer, _to_norm_type(log_grad_norm), grad=True + ) + self.log(f"{prefix}grad_norm", norm, on_step=True, on_epoch=False) + if log_grad_norm_per_param := config.trainer.optimizer.log_grad_norm_per_param: + norm_type = _to_norm_type(log_grad_norm_per_param) + self.log_dict( + { + f"{prefix}{k}": v + for k, v in grad_norm(self, norm_type, grad=True).items() + } + ) + + # Parameter norm logging + if log_param_norm := config.trainer.optimizer.log_param_norm: + norm = self.compute_parameter_norm( + optimizer, _to_norm_type(log_param_norm), grad=False + ) + self.log(f"{prefix}param_norm", norm, on_step=True, on_epoch=False) + if ( + log_param_norm_per_param + := config.trainer.optimizer.log_param_norm_per_param + ): + norm_type = _to_norm_type(log_param_norm_per_param) + self.log_dict( + { + f"{prefix}{k}": v + for k, v in grad_norm(self, norm_type, grad=False).items() + } + ) + + @override + def configure_gradient_clipping( + self, + optimizer: Optimizer, + gradient_clip_val: int | float | None = None, + gradient_clip_algorithm: str | None = None, + ): + self._perform_norm_logging(optimizer, prefix="train/") + + super().configure_gradient_clipping( + optimizer, gradient_clip_val, gradient_clip_algorithm + ) diff --git a/src/jmp/lightning/model/modules/parameter_hooks.py b/src/jmp/lightning/model/modules/parameter_hooks.py new file mode 100644 index 0000000..6262b0b --- /dev/null +++ b/src/jmp/lightning/model/modules/parameter_hooks.py @@ -0,0 +1,55 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from typing import Callable, cast + +import torch.nn as nn +from lightning.pytorch import LightningModule, Trainer +from typing_extensions import override + +from ...util.typing_utils import mixin_base_type +from ..config import BaseConfig +from .callback import CallbackRegistrarModuleMixin + +log = getLogger(__name__) + + +class ParameterHookModuleMixin(mixin_base_type(CallbackRegistrarModuleMixin)): + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.after_backward_hooks: list[ + tuple[list[nn.Parameter], Callable[[nn.Parameter], None]] + ] = [] + + def on_after_backward(_trainer: Trainer, pl_module: LightningModule): + nonlocal self + + config = cast(BaseConfig, pl_module.hparams) + if not config.trainer.supports_parameter_hooks: + return + + log.debug("Running after_backward hooks...") + for parameters, hook in self.after_backward_hooks: + for parameter in parameters: + hook(parameter) + log.debug( + f"Done running after_backward hooks. (len={len(self.after_backward_hooks)})" + ) + + self.register_callback(on_after_backward=on_after_backward) + + def register_parameter_hook( + self, parameters: list[nn.Parameter], hook: Callable[[nn.Parameter], None] + ): + self.after_backward_hooks.append((parameters, hook)) + log.debug( + f"Registered after_backward hook {hook} for {len(parameters)} parameters" + ) diff --git a/src/jmp/lightning/model/modules/profiler.py b/src/jmp/lightning/model/modules/profiler.py new file mode 100644 index 0000000..75afaab --- /dev/null +++ b/src/jmp/lightning/model/modules/profiler.py @@ -0,0 +1,32 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from lightning.pytorch import LightningDataModule, LightningModule +from lightning.pytorch.profilers import PassThroughProfiler + +from ...util.typing_utils import mixin_base_type + + +class ProfilerMixin(mixin_base_type(LightningModule)): + @property + def profiler(self): + if not isinstance(self, (LightningModule, LightningDataModule)): + raise TypeError( + "`profiler` can only be used on LightningModule or LightningDataModule" + ) + + if (trainer := self.trainer) is None: + raise RuntimeError("trainer is not defined") + + if not hasattr(trainer, "profiler"): + raise RuntimeError("trainer does not have profiler") + + if (profiler := getattr(trainer, "profiler")) is None: + profiler = PassThroughProfiler() + + return profiler diff --git a/src/jmp/lightning/model/modules/rlp_sanity_checks.py b/src/jmp/lightning/model/modules/rlp_sanity_checks.py new file mode 100644 index 0000000..a9df2d8 --- /dev/null +++ b/src/jmp/lightning/model/modules/rlp_sanity_checks.py @@ -0,0 +1,153 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from typing import cast + +from lightning.pytorch import LightningModule, Trainer +from typing_extensions import override + +from ...util.typing_utils import mixin_base_type +from ..config import BaseConfig +from .callback import CallbackModuleMixin + +log = getLogger(__name__) + + +def _on_train_start_callback(trainer: Trainer, pl_module: LightningModule): + config = cast(BaseConfig, pl_module.hparams) + if config.trainer.reduce_lr_on_plateau_sanity_checks == "disable": + return + + # if no lr schedulers, return + if not trainer.lr_scheduler_configs: + return + + errors: list[str] = [] + disable_message = ( + "Otherwise, set `config.trainer.reduce_lr_on_plateau_sanity_checks='disable'` " + "to disable this sanity check." + ) + + for lr_scheduler_config in trainer.lr_scheduler_configs: + if not lr_scheduler_config.reduce_on_plateau: + continue + + match lr_scheduler_config.interval: + case "epoch": + # we need to make sure that the trainer runs val every `frequency` epochs + + # If `trainer.check_val_every_n_epoch` is None, then Lightning + # will run val every `int(trainer.val_check_interval)` steps. + # So, first we need to make sure that `trainer.val_check_interval` is not None first. + if trainer.check_val_every_n_epoch is None: + errors.append( + "Trainer is not running validation at epoch intervals " + "(i.e., `trainer.check_val_every_n_epoch` is None) but " + f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used." + f"Please set `config.trainer.check_val_every_n_epoch={lr_scheduler_config.frequency}`. " + + disable_message + ) + + # Second, we make sure that the trainer runs val at least every `frequency` epochs + if ( + trainer.check_val_every_n_epoch is not None + and lr_scheduler_config.frequency % trainer.check_val_every_n_epoch + != 0 + ): + errors.append( + f"Trainer is not running validation every {lr_scheduler_config.frequency} epochs but " + f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used." + f"Please set `config.trainer.check_val_every_n_epoch` to a multiple of {lr_scheduler_config.frequency}. " + + disable_message + ) + + case "step": + # In this case, we need to make sure that the trainer runs val at step intervals + # that are multiples of `frequency`. + + # First, we make sure that validation is run at step intervals + if trainer.check_val_every_n_epoch is not None: + errors.append( + "Trainer is running validation at epoch intervals " + "(i.e., `trainer.check_val_every_n_epoch` is not None) but " + f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used." + "Please set `config.trainer.check_val_every_n_epoch=None` " + f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. " + + disable_message + ) + + # Second, we make sure `trainer.val_check_interval` is an integer + if not isinstance(trainer.val_check_interval, int): + errors.append( + f"Trainer is not running validation at step intervals " + f"(i.e., `trainer.val_check_interval` is not an integer) but " + f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} is used." + "Please set `config.trainer.val_check_interval=None` " + f"and `config.trainer.val_check_interval={lr_scheduler_config.frequency}`. " + + disable_message + ) + + # Third, we make sure that the trainer runs val at least every `frequency` steps + if ( + isinstance(trainer.val_check_interval, int) + and trainer.val_check_interval % lr_scheduler_config.frequency != 0 + ): + errors.append( + f"Trainer is not running validation every {lr_scheduler_config.frequency} steps but " + f"a ReduceLRPlateau scheduler with interval={lr_scheduler_config.interval} and frequency={lr_scheduler_config.frequency} is used." + "Please set `config.trainer.val_check_interval` " + f"to a multiple of {lr_scheduler_config.frequency}. " + + disable_message + ) + + case _: + pass + + if not errors: + return + + message = ( + "ReduceLRPlateau sanity checks failed with the following errors:\n" + + "\n".join(errors) + ) + match config.trainer.reduce_lr_on_plateau_sanity_checks: + case "warn": + log.warning(message) + case "error": + raise ValueError(message) + case _: + pass + + +class RLPSanityCheckModuleMixin(mixin_base_type(CallbackModuleMixin)): + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + global _on_train_start_callback + self.register_callback(on_train_start=_on_train_start_callback) + + def determine_reduce_lr_on_plateau_interval_frequency(self): + if (trainer := self._trainer) is None: + raise RuntimeError( + "Could not determine the frequency of ReduceLRPlateau scheduler " + "because `self.trainer` is None." + ) + + # if trainer.check_val_every_n_epoch is an integer, then we run val at epoch intervals + if trainer.check_val_every_n_epoch is not None: + return "epoch", trainer.check_val_every_n_epoch + + # otherwise, we run val at step intervals + if not isinstance(trainer.val_check_batch, int): + raise ValueError( + "Could not determine the frequency of ReduceLRPlateau scheduler " + f"because {trainer.val_check_batch=} is not an integer." + ) + return "step", trainer.val_check_batch diff --git a/src/jmp/lightning/model/modules/shared_parameters.py b/src/jmp/lightning/model/modules/shared_parameters.py new file mode 100644 index 0000000..373ad36 --- /dev/null +++ b/src/jmp/lightning/model/modules/shared_parameters.py @@ -0,0 +1,70 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from typing import Sequence, cast + +import torch.nn as nn +from lightning.pytorch import LightningModule, Trainer +from typing_extensions import override + +from ...util.typing_utils import mixin_base_type +from ..config import BaseConfig +from .callback import CallbackRegistrarModuleMixin + +log = getLogger(__name__) + + +class SharedParametersModuleMixin(mixin_base_type(CallbackRegistrarModuleMixin)): + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.shared_parameters: list[tuple[nn.Parameter, int | float]] = [] + + def on_after_backward(_trainer: Trainer, pl_module: LightningModule): + nonlocal self + + config = cast(BaseConfig, pl_module.hparams) + if not config.trainer.supports_shared_parameters: + return + + log.debug(f"Scaling {len(self.shared_parameters)} shared parameters...") + warned_shared_param_no_grad = False + for p, factor in self.shared_parameters: + if not hasattr(p, "grad") or p.grad is None: + warned_shared_param_no_grad = True + continue + + _ = p.grad.data.div_(factor) + + if warned_shared_param_no_grad: + log.warning( + "Some shared parameters do not have a gradient. " + "Please check if all shared parameters are used " + "and point to PyTorch parameters." + ) + + log.debug( + f"Done scaling shared parameters. (len={len(self.shared_parameters)})" + ) + + self.register_callback(on_after_backward=on_after_backward) + + def register_shared_parameters( + self, parameters: Sequence[tuple[nn.Parameter, int | float]] + ): + for parameter, factor in parameters: + if not isinstance(parameter, nn.Parameter): + raise ValueError("Shared parameters must be PyTorch parameters") + if not isinstance(factor, (int, float)): + raise ValueError("Factor must be an integer or float") + + self.shared_parameters.append((parameter, factor)) + + log.info(f"Registered {len(parameters)} shared parameters") diff --git a/src/jmp/lightning/model/modules/wandb.py b/src/jmp/lightning/model/modules/wandb.py new file mode 100644 index 0000000..44ce917 --- /dev/null +++ b/src/jmp/lightning/model/modules/wandb.py @@ -0,0 +1,72 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from typing import cast + +import torch.nn as nn +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.loggers import WandbLogger +from typing_extensions import override + +from ...util.typing_utils import mixin_base_type +from ..config import BaseConfig +from .callback import CallbackModuleMixin + +log = getLogger(__name__) + + +class WandbWrapperMixin(mixin_base_type(CallbackModuleMixin)): + @override + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup(trainer: Trainer, pl_module: LightningModule, stage: str): + nonlocal self + + config = cast(BaseConfig, self.hparams) + if ( + not config.trainer.logging.enabled + or not config.trainer.logging.wandb + or not config.trainer.logging.wandb.watch + or not config.trainer.logging.wandb.watch.enabled + ): + return + + if ( + logger := next( + ( + logger + for logger in trainer.loggers + if isinstance(logger, WandbLogger) + ), + None, + ) + ) is None: + log.warning("Could not find wandb logger or module to log") + return + + if (module := self.wandb_log_module()) is None: + log.warning("Could not find module to log to wandb") + return + + if getattr(self, "_model_watched", False): + return + + logger.watch( + module, + log=cast(str, config.trainer.logging.wandb.watch.log), + log_freq=config.trainer.logging.wandb.watch.log_freq, + log_graph=config.trainer.logging.wandb.watch.log_graph, + ) + setattr(self, "_model_watched", True) + + self.register_callback(setup=setup) + + def wandb_log_module(self) -> nn.Module | None: + return self diff --git a/src/jmp/lightning/modules/normalizer.py b/src/jmp/lightning/modules/normalizer.py new file mode 100644 index 0000000..f1d99aa --- /dev/null +++ b/src/jmp/lightning/modules/normalizer.py @@ -0,0 +1,28 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch + +from ..config import TypedConfig + + +class NormalizerConfig(TypedConfig): + enabled: bool = True + + mean: float = 0.0 + std: float = 1.0 + + def normalize(self, x: torch.Tensor): + if not self.enabled: + return x + return (x - self.mean) / self.std + + def denormalize(self, x: torch.Tensor): + if not self.enabled: + return x + return x * self.std + self.mean diff --git a/src/jmp/lightning/runner.py b/src/jmp/lightning/runner.py new file mode 100644 index 0000000..4ccbfcb --- /dev/null +++ b/src/jmp/lightning/runner.py @@ -0,0 +1,580 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +import getpass +import os +import subprocess +import tempfile +import traceback +from collections import Counter +from contextlib import ExitStack +from datetime import timedelta +from functools import wraps +from logging import getLogger +from pathlib import Path +from typing import Generic, Protocol, Sequence, cast, overload, runtime_checkable + +import cloudpickle as pickle +from submitit import AutoExecutor +from tqdm.auto import tqdm +from typing_extensions import TypeVar, TypeVarTuple, Unpack, deprecated, override + +from .model.config import BaseConfig +from .trainer import Trainer +from .util.environment import ( + remove_slurm_environment_variables, + remove_wandb_environment_variables, +) +from .util.snapshot import snapshot_modules + +log = getLogger(__name__) + + +TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True) +TReturn = TypeVar("TReturn", default=None, infer_variance=True) +TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]]) + + +@runtime_checkable +class RunProtocol(Protocol[TConfig, TReturn, Unpack[TArguments]]): + def __call__(self, config: TConfig, *args: Unpack[TArguments]) -> TReturn: ... + + +class Runner(Generic[TConfig, TReturn, Unpack[TArguments]]): + DEFAULT_ENV = {} + SNAPSHOT_ENV_NAME = "LL_SNAPSHOT" + + @classmethod + def active_snapshot(cls) -> Path | None: + if (snapshot := os.environ.get(cls.SNAPSHOT_ENV_NAME)) is not None: + return Path(snapshot) + return None + + @override + def __init__( + self, + run: RunProtocol[TConfig, TReturn, Unpack[TArguments]], + *, + slurm_job_name: str = "jmplightning", + validate_config_before_run: bool = True, + validate_strict: bool = True, + ): + """This is the initialization function for a class that takes in a run protocol, an auto wrap run + boolean, and a slurm job name string. + + Parameters + ---------- + run : RunProtocol[TConfig, Unpack[TArguments]] + `run` is an instance of a class that implements the `RunProtocol` interface. It represents the main function or entry point of the program that will be executed. + slurm_job_name : str, optional + The `slurm_job_name` parameter is a string that represents the name of the job when submitting it to a SLURM cluster. + validate_config_before_run : bool, optional + The `validate_config_before_run` parameter is a boolean that represents whether or not to validate the configuration before running the program. + validate_strict: bool, optional + Should `validate_config_before_run` be strict? If `True`, the configuration will be validated strictly. If `False`, the configuration will be validated non-strictly. + """ + + super().__init__() + + self._run = run + self.slurm_job_name = slurm_job_name + self.validate_config_before_run = validate_config_before_run + self.validate_strict = validate_strict + self._init_kwargs = { + "slurm_job_name": slurm_job_name, + "validate_config_before_run": validate_config_before_run, + "validate_strict": validate_strict, + } + + @property + def _run_fn(self) -> RunProtocol[TConfig, TReturn, Unpack[TArguments]]: + run = self._run + + @wraps(run) + def wrapped_run(config: TConfig, *args: Unpack[TArguments]) -> TReturn: + nonlocal self + + with ExitStack() as stack: + nonlocal run + + # If `auto_call_trainer_init_from_runner`, we call `Trainer.runner_init` before running the program. + if config.runner.auto_call_trainer_init_from_runner: + stack.enter_context(Trainer.runner_init(config)) + + # If `validate_config_before_run`, we validate the configuration before running the program. + if self.validate_config_before_run: + config = type(config).model_deep_validate( + config, strict=self.validate_strict + ) + + if config.trainer.auto_wrap_trainer: + stack.enter_context(Trainer.context(config)) + log.critical("Auto-wrapping run in Trainer context") + + return run(config, *args) + + raise RuntimeError("ExitStack should never raise an exception") + + return wrapped_run + + @staticmethod + def _resolve_run( + run: TConfig | tuple[TConfig, Unpack[TArguments]], + copy_config: bool = True, + reset_id: bool = False, + ): + if isinstance(run, tuple): + (config, *args) = run + else: + config = cast(TConfig, run) + args = [] + args = cast(tuple[Unpack[TArguments]], args) + if copy_config: + config = copy.deepcopy(config) + if reset_id: + config.id = BaseConfig.generate_id(ignore_rng=True) + return (config, args) + + @staticmethod + def _resolve_runs( + runs: Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]], + copy_config: bool = True, + reset_id: bool = False, + ): + resolved: list[tuple[TConfig, tuple[Unpack[TArguments]]]] = [] + for run in runs: + resolved.append( + Runner._resolve_run(run, copy_config=copy_config, reset_id=reset_id) + ) + + return resolved + + @deprecated("Use __call__ instead") + @overload + def local( + self, + run: TConfig | tuple[TConfig, Unpack[TArguments]], + /, + *, + env: dict[str, str] | None = None, + reset_id: bool = True, + ) -> TReturn: ... + + @deprecated("Use __call__ instead") + @overload + def local( + self, + run_1: TConfig | tuple[TConfig, Unpack[TArguments]], + run_2: TConfig | tuple[TConfig, Unpack[TArguments]], + /, + *runs: TConfig | tuple[TConfig, Unpack[TArguments]], + env: dict[str, str] | None = None, + reset_id: bool = True, + ) -> list[TReturn]: ... + + @deprecated("Use __call__ instead") + def local( + self, + *runs: TConfig | tuple[TConfig, Unpack[TArguments]], + env: dict[str, str] | None = None, + reset_id: bool = True, + ): + return_values: list[TReturn] = [] + for run in runs: + config, args = self._resolve_run(run) + if reset_id: + config.id = BaseConfig.generate_id(ignore_rng=True) + + env = {**self.DEFAULT_ENV, **(env or {})} + env_old = {k: os.environ.get(k, None) for k in env} + os.environ.update(env) + try: + return_value = self._run_fn(config, *args) + return_values.append(return_value) + finally: + for k, v in env_old.items(): + if v is None: + _ = os.environ.pop(k, None) + else: + os.environ[k] = v + + return return_values[0] if len(return_values) == 1 else return_values + + def __call__( + self, + runs: Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]], + env: dict[str, str] | None = None, + reset_id: bool = True, + ): + """ + Runs a list of configs locally. + + Parameters + ---------- + runs : Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]] + A sequence of runs to submit. + env : dict[str, str], optional + Additional environment variables to set. + reset_id : bool, optional + Whether to reset the id of the runs before launching them. + """ + return_values: list[TReturn] = [] + for run in runs: + config, args = self._resolve_run(run) + if reset_id: + config.id = BaseConfig.generate_id(ignore_rng=True) + + env = {**self.DEFAULT_ENV, **(env or {})} + env_old = {k: os.environ.get(k, None) for k in env} + os.environ.update(env) + try: + return_value = self._run_fn(config, *args) + return_values.append(return_value) + finally: + for k, v in env_old.items(): + if v is None: + _ = os.environ.pop(k, None) + else: + os.environ[k] = v + + return return_values[0] if len(return_values) == 1 else return_values + + def _launch_session( + self, + config_paths: list[Path], + conda_env: str | None, + session_name: str, + env: dict[str, str], + what_if: bool = False, + ): + # All we need to do here is launch `python -m jmp.lightning.local_sessions_runner` with the config paths as arguments. The `local_sessions_runner` will take care of the rest. + # Obviously, the command above needs to be run in a screen session, so we can come back to it later. + + if not conda_env: + command = ( + ["screen", "-dmS", session_name] + + ["python", "-m", "jmp.lightning.local_sessions_runner"] + + [str(p.absolute()) for p in config_paths] + ) + else: + command = ( + ["screen", "-dmS", session_name] + + [ + "conda", + "run", + "--live-stream", + "-n", + conda_env, + "python", + "-m", + "jmp.lightning.local_sessions_runner", + ] + + [str(p.absolute()) for p in config_paths] + ) + if not what_if: + log.critical(f"Launching session with command: {command}") + _ = subprocess.run(command, env=env, check=True) + + return command + + def local_sessions( + self, + runs: Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]], + sessions: int | list[dict[str, str]], + config_pickle_save_path: Path | None = None, + reset_id: bool = True, + what_if: bool = False, + ): + """ + Launches len(sessions) local runs in different environments using `screen`. + + Parameters + ---------- + runs : Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]] + A sequence of runs to launch. + sessions : list[dict[str, str]] + A list of environment variables to use for each session. + config_pickle_save_path : Path, optional + The path to save the config pickles to. If `None`, a temporary directory will be created. + reset_id : bool, optional + Whether to reset the id of the runs before launching them. + what_if : bool, optional + If `True`, the sessions will not be launched, but the command to launch them will be printed. + + Returns + ------- + list[TReturn] + A list of names for each screen session. + """ + + if isinstance(sessions, int): + sessions = [{} for _ in range(sessions)] + + # This only works in conda environments, so we need to make sure we're in one + if (current_env := os.environ.get("CONDA_DEFAULT_ENV")) is None: + raise RuntimeError("This function only works in conda environments.") + + if config_pickle_save_path is None: + config_pickle_save_path = Path(tempfile.mkdtemp()) + + resolved_runs = self._resolve_runs(runs, reset_id=reset_id) + self._validate_runs(resolved_runs) + + # Save all configs to pickle files + config_paths: list[Path] = [] + for i, config in enumerate(resolved_runs): + config_path = config_pickle_save_path / f"ll_{i:03d}.pkl" + config_paths.append(config_path) + config = tuple([config[0], *config[1]]) + with config_path.open("wb") as f: + pickle.dump((self._run, self._init_kwargs, config), f) + + # Launch all sessions + names: list[str] = [] + commands: list[str] = [] + n_sessions = len(sessions) + for i, session in enumerate(sessions): + session_env = {**self.DEFAULT_ENV, **session} + # Get the shared project name + project_names = set([config.project for config, _ in resolved_runs]) + if len(project_names) == 1: + project = project_names.pop() + else: + project = "session" + session_name = f"ll_{project}_{i:03d}" + command = self._launch_session( + config_paths, + current_env, + session_name, + session_env, + what_if=what_if, + ) + names.append(session_name) + if what_if: + # log.critical(f"Sesssion {i+1}/{n_sessions} command: {command_str}") + command_prefix = " ".join(f'{k}="{v}"' for k, v in session_env.items()) + command_str = " ".join(command) + commands.append(f"{command_prefix} {command_str}") + else: + log.critical(f"Launched session {i+1}/{n_sessions}") + + if what_if: + # Print the full command so the user can copy-paste it + print( + "The sessions were not launched because `what_if` was set. Please copy-paste the following command to launch the sessions." + ) + for command in commands: + print(command) + + return names + + @staticmethod + def _n_gpus(): + import torch + + return torch.cuda.device_count() + + def local_session_per_gpu( + self, + runs: Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]], + config_pickle_save_path: Path | None = None, + reset_id: bool = True, + what_if: bool = False, + ): + """ + Launches len(sessions) local runs in different environments using `screen`. + + Parameters + ---------- + runs : Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]] + A sequence of runs to launch. + config_pickle_save_path : Path, optional + The path to save the config pickles to. If `None`, a temporary directory will be created. + reset_id : bool, optional + Whether to reset the id of the runs before launching them. + what_if : bool, optional + If `True`, the sessions will not be launched, but the command to launch them will be printed. + + Returns + ------- + list[TReturn] + A list of names for each screen session. + """ + # Get the number of GPUs + n_gpus = self._n_gpus() + log.critical(f"Detected {n_gpus} GPUs. Launching one session per GPU.") + + # Create a session for each GPU + sessions = [{"CUDA_VISIBLE_DEVICES": str(i)} for i in range(n_gpus)] + + # Launch the sessions + return self.local_sessions( + runs, + sessions, + config_pickle_save_path=config_pickle_save_path, + reset_id=reset_id, + what_if=what_if, + ) + + def fast_dev_run( + self, + runs: Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]], + env: dict[str, str] | None = None, + n_batches: int = 1, + stop_on_error: bool = True, + ): + """ + Runs a list of configs locally with `LightningTrainer.fast_dev_run = True`. + + Parameters + ---------- + runs : Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]] + A sequence of runs to submit. + env : dict[str, str], optional + Additional environment variables to set. + n_batches : int, optional + The number of batches to run for `fast_dev_run`. + stop_on_error : bool, optional + Whether to stop on error. + """ + resolved_runs = self._resolve_runs(runs) + self._validate_runs(resolved_runs) + + return_values: list[TReturn] = [] + + for config, args in tqdm(resolved_runs, desc="Fast dev run"): + run_id = config.id + run_name = config.name + try: + config.trainer.fast_dev_run = n_batches + return_value = self.local((config, *args), env=env, reset_id=True) + return_values.append(return_value) + except BaseException as e: + # print full traceback + log.critical(f"Error in run with {run_id=} ({run_name=}): {e}") + traceback.print_exc() + if stop_on_error: + raise + + return return_values + + @staticmethod + def _validate_runs(runs: list[tuple[TConfig, tuple[Unpack[TArguments]]]]): + if not runs: + raise ValueError("No run configs provided.") + + id_counter = Counter(config.id for config, _ in runs if config.id is not None) + for id, count in id_counter.items(): + if count > 1: + raise ValueError(f"Duplicate id {id=}") + + @remove_slurm_environment_variables() + @remove_wandb_environment_variables() + def submit( + self, + runs: Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]], + *, + gpus: int, + nodes: int, + partition: str, + cpus_per_task: int, + snapshot: bool | Path, + constraint: str | None = None, + timeout: timedelta | None = None, + memory: int | None = None, + email: str | None = None, + slurm_additional_parameters: dict[str, str] | None = None, + slurm_setup: list[str] | None = None, + snapshot_base: Path | None = None, + env: dict[str, str] | None = None, + ): + """ + Submits a list of configs to a SLURM cluster. + + Parameters + ---------- + runs : Sequence[TConfig] | Sequence[tuple[TConfig, Unpack[TArguments]]] + A sequence of runs to submit. + gpus : int + The number of GPUs per node. + nodes : int + The number of nodes. + partition : str + The name of the partition to submit to. + cpus_per_task : int + The number of CPUs per task. + snapshot : bool | Path + If `True`, snapshots the current environment. If a `Path` is provided, it will be used as the snapshot directory. + constraint : str, optional + The name of the constraint to use. + timeout : timedelta, optional + The maximum time to run the job for. + memory : int, optional + The amount of memory to use. + email : str, optional + The email to send notifications to. + slurm_additional_parameters : dict[str, str], optional + Additional parameters to pass to the SLUR + """ + resolved_runs = self._resolve_runs(runs) + self._validate_runs(resolved_runs) + + if snapshot_base is None: + current_user = getpass.getuser() + snapshot_base = Path(f"/checkpoint/{current_user}/ll_snapshots/") + + if snapshot is True: + snapshot = snapshot_modules(snapshot_base, ["jmp", "submitit"]).absolute() + + env = {**self.DEFAULT_ENV, **(env or {})} + + base_path = Path(".") / "slurm_logs" + base_path.mkdir(exist_ok=True, parents=True) + + additional_parameters = {} + if email: + additional_parameters.update({"mail_user": email, "mail_type": "FAIL"}) + if constraint: + additional_parameters.update({"constraint": constraint}) + if slurm_additional_parameters: + additional_parameters.update(slurm_additional_parameters) + + setup = [] + if env: + setup.extend(f"export {k}={v}" for k, v in env.items()) + if slurm_setup: + setup.extend(slurm_setup) + if snapshot: + snapshot_str = str(snapshot.resolve().absolute()) + setup.append(f"export {self.SNAPSHOT_ENV_NAME}={snapshot_str}") + setup.append(f"export PYTHONPATH={snapshot_str}:$PYTHONPATH") + + parameters_kwargs = dict( + name=self.slurm_job_name, + mem_gb=memory, + cpus_per_task=cpus_per_task, + tasks_per_node=gpus, + gpus_per_node=gpus, + nodes=nodes, + slurm_partition=partition, + slurm_additional_parameters=additional_parameters, + slurm_setup=setup, + ) + if timeout: + parameters_kwargs["timeout_min"] = int(timeout.total_seconds() / 60) + + executor = AutoExecutor(folder=base_path / "%j") + executor.update_parameters(**parameters_kwargs) + + map_array_args = list(zip(*[(c, *args) for c, args in resolved_runs])) + log.critical(f"Submitting {len(resolved_runs)} jobs to {partition}.") + jobs = executor.map_array(self._run_fn, *map_array_args) + for job, (config, _) in zip(jobs, resolved_runs): + log.critical(f"[id={config.id}] Submitted job: {job.job_id} to {partition}") + return jobs diff --git a/src/jmp/lightning/trainer/__init__.py b/src/jmp/lightning/trainer/__init__.py new file mode 100644 index 0000000..53143e9 --- /dev/null +++ b/src/jmp/lightning/trainer/__init__.py @@ -0,0 +1,11 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .trainer import Trainer + +__all__ = ["Trainer"] diff --git a/src/jmp/lightning/trainer/logging.py b/src/jmp/lightning/trainer/logging.py new file mode 100644 index 0000000..c0f3c8a --- /dev/null +++ b/src/jmp/lightning/trainer/logging.py @@ -0,0 +1,131 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from pathlib import Path +from typing import Literal + +from lightning.pytorch.loggers import Logger +from lightning.pytorch.loggers.csv_logs import CSVLogger +from lightning.pytorch.loggers.tensorboard import TensorBoardLogger +from lightning.pytorch.loggers.wandb import WandbLogger + +from ..model.config import BaseConfig + +log = getLogger(__name__) + + +def default_root_dir(config: BaseConfig, *, logs_dirname: str = "lightning_logs"): + base_path = (Path.cwd() / logs_dirname).resolve().absolute() + path = base_path / config.id + path.mkdir(parents=True, exist_ok=True) + return path + + +def _default_loggers( + *, + base_path: str | Path = ".", + id: str | None = None, + name: str | None = None, + project: str | None = None, + csv: bool = True, + tensorboard: bool = True, + wandb: bool = True, + log_model: bool | Literal["all"] = False, + notes: str | None = None, + tags: list[str] | None = None, +) -> list[Logger]: + base_path = Path(base_path) + + loggers: list[Logger] = [] + if wandb: + log.info(f"Creating W&B logger for {project or 'lightning_logs'} with {id=}.") + loggers.append( + WandbLogger( + save_dir=base_path, + project=project or "lightning_logs", + name=name, + version=id, + log_model=log_model, + notes=notes, + tags=tags, + ) + ) + if csv: + log.info(f"Creating CSV logger for {base_path / 'csv'} with {id=}.") + loggers.append( + CSVLogger( + save_dir=base_path / "csv", + name=name or ".", + version=id, + ) + ) + if tensorboard: + log.info( + f"Creating TensorBoard logger for {base_path / 'tensorboard'} with {id=}." + ) + loggers.append( + TensorBoardLogger( + save_dir=base_path / "tensorboard", + name=name, + version=id, + ) + ) + return loggers + + +def loggers_from_config(config: BaseConfig): + logging_config = config.trainer.logging + if not logging_config.enabled or config.trainer.logger is False: + return [] + + wandb_log_model = False + if logging_config.wandb is not None: + match wandb_log_model := logging_config.wandb.log_model: + case True | False | "all": + log.info(f"W&B logging model: {wandb_log_model}.") + case _: + raise ValueError(f"Invalid wandb log_model value {wandb_log_model}.") + + return _default_loggers( + base_path=default_root_dir(config), + id=config.id, + name=config.name, + project=config.project, + csv=logging_config.csv is not None and logging_config.csv.enabled, + tensorboard=logging_config.tensorboard is not None + and logging_config.tensorboard.enabled, + wandb=logging_config.wandb is not None and logging_config.wandb.enabled, + log_model=wandb_log_model, + tags=config.tags, + notes=( + "\n".join(f"- {note}" for note in config.notes) if config.notes else None + ), + ) + + +def validate_logger(logger: Logger, run_id: str): + match logger: + case CSVLogger() | TensorBoardLogger() | WandbLogger(): + if logger.version != run_id: + raise ValueError( + f"{logger.__class__.__qualname__} version {logger.version} does not match run_id {run_id}" + ) + case _: + log.warning( + f"Logger {logger.__class__.__qualname__} does not support run_id, ignoring." + ) + + +def finalize_loggers(loggers: list[Logger]): + for logger in loggers: + match logger: + case WandbLogger(_experiment=experiment) if experiment is not None: + _ = experiment.finish() + case _: + pass diff --git a/src/jmp/lightning/trainer/trainer.py b/src/jmp/lightning/trainer/trainer.py new file mode 100644 index 0000000..9ac6785 --- /dev/null +++ b/src/jmp/lightning/trainer/trainer.py @@ -0,0 +1,449 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import contextlib +import logging +from collections import abc +from pathlib import Path +from types import NoneType +from typing import Any, Callable + +import torch +from lightning.pytorch import LightningModule +from lightning.pytorch import Trainer as LightningTrainer +from lightning.pytorch.callbacks import ModelCheckpoint, OnExceptionCheckpoint +from lightning.pytorch.plugins.environments import SLURMEnvironment +from lightning.pytorch.profilers import Profiler +from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT +from lightning_fabric.utilities.types import _PATH +from typing_extensions import override + +from ..model.config import ( + BaseConfig, + BaseProfilerConfig, + PythonLogging, + RunnerOutputSaveConfig, +) +from ..util import seed +from ..util.environment import set_additional_env_vars +from ..util.typing_utils import copy_method_with_param +from .logging import ( + default_root_dir, + finalize_loggers, + loggers_from_config, + validate_logger, +) + +log = logging.getLogger(__name__) + + +def _save_output_dir(root_config: BaseConfig, config: RunnerOutputSaveConfig): + if not (dirpath := config.dirpath): + dirpath = default_root_dir(root_config, logs_dirname="ll_runner_logs") + dirpath = Path(dirpath).resolve() + + # Make sure that the directory exists + dirpath.mkdir(parents=True, exist_ok=True) + + return dirpath + + +def _default_log_handlers(root_config: BaseConfig): + if (config := root_config.runner.save_output) is None or not config.enabled: + return + + # Get the directory path + dirpath = _save_output_dir(root_config, config) + + # Capture the logs to `dirpath`/log.log + log_file = dirpath / "log.log" + log_file.touch(exist_ok=True) + yield logging.FileHandler(log_file) + + +def _setup_logger(root_config: BaseConfig, config: PythonLogging): + if config.lovely_tensors: + try: + import lovely_tensors + + lovely_tensors.monkey_patch() + except ImportError: + log.warning( + "Failed to import lovely-tensors. Ignoring pretty PyTorch tensor formatting" + ) + + if config.lovely_numpy: + try: + import lovely_numpy + + lovely_numpy.set_config(repr=lovely_numpy.lovely) + except ImportError: + log.warning( + "Failed to import lovely-numpy. Ignoring pretty numpy array formatting" + ) + + log_handlers: list[logging.Handler] = [*_default_log_handlers(root_config)] + if config.rich: + try: + from rich.logging import RichHandler + + log_handlers.append(RichHandler()) + except ImportError: + log.warning( + "Failed to import rich. Falling back to default Python logging." + ) + + logging.basicConfig( + level=config.log_level, + format="%(message)s", + datefmt="[%X]", + handlers=log_handlers, + ) + + logging.basicConfig(level=config.log_level) + + +class Trainer(LightningTrainer): + _finalizers: list[Callable[[], None]] = [] + + def finalize(self): + """ + Call this method to clean up after training. + """ + finalize_loggers(self.loggers) + + @staticmethod + def ll_default_root_dir( + config: BaseConfig, *, logs_dirname: str = "lightning_logs" + ): + return default_root_dir(config, logs_dirname=logs_dirname) + + @classmethod + def setup_python_logging(cls, config: BaseConfig): + _setup_logger(config, config.trainer.python_logging) + + @classmethod + @contextlib.contextmanager + def output_save_context(cls, root_config: BaseConfig): + if (config := root_config.runner.save_output) is None or not config.enabled: + yield + return + + # Get the directory path + dirpath = _save_output_dir(root_config, config) + + # Capture the stdout and stderr logs to `dirpath`/stdout.log and `dirpath`/stderr.log + stdout_log = dirpath / "stdout.log" + stderr_log = dirpath / "stderr.log" + stdout_log.touch(exist_ok=True) + stderr_log.touch(exist_ok=True) + with stdout_log.open("a") as file: + with contextlib.redirect_stdout(file): + with stderr_log.open("a") as file: + with contextlib.redirect_stderr(file): + yield + + @classmethod + @contextlib.contextmanager + def ll_initialize(cls, config: BaseConfig): + with contextlib.ExitStack() as stack: + if not config.runner.auto_call_trainer_init_from_runner: + stack.enter_context(cls.runner_init(config)) + + if config.trainer.auto_set_default_root_dir: + if config.trainer.default_root_dir: + raise ValueError( + "You have set both `config.trainer.default_root_dir` and `config.trainer.auto_set_default_root_dir`. " + "Please set only one of them." + ) + config.trainer.default_root_dir = str( + cls.ll_default_root_dir(config).absolute() + ) + log.critical(f"Setting {config.trainer.default_root_dir=}.") + + yield + + @classmethod + @contextlib.contextmanager + def runner_init(cls, config: BaseConfig): + with contextlib.ExitStack() as stack: + cls.setup_python_logging(config) + # Save stdout/stderr to a file + stack.enter_context(Trainer.output_save_context(config)) + yield + + @classmethod + def ll_default_callbacks(cls, config: BaseConfig): + if config.trainer.on_exception_checkpoint: + if config.trainer.default_root_dir is None: + raise ValueError( + "You must specify `config.trainer.default_root_dir` " + "to use `config.trainer.on_exception_checkpoint`." + ) + log_dir = Path(config.trainer.default_root_dir) + yield OnExceptionCheckpoint(log_dir, filename=f"on_exception_{config.id}") + + @classmethod + @contextlib.contextmanager + def context(cls, config: BaseConfig): + with contextlib.ExitStack() as stack: + stack.enter_context(cls.ll_initialize(config)) + + cls._finalizers.clear() + if config.trainer.seed is not None: + stack.enter_context( + seed.seed_context( + config.trainer.seed, workers=config.trainer.seed_workers + ) + ) + + additional_nccl_env_vars: dict[str, str] = {} + if config.trainer.set_nccl_optimal_params: + # We need to set these env vars before the NCCL library is loaded. + # Reportedly, the training performance can be improved quite a bit (see). + # Details on all available env vars: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html + additional_nccl_env_vars["NCCL_NSOCKS_PERTHREAD"] = "4" + additional_nccl_env_vars["NCCL_SOCKET_NTHREADS"] = "2" + + if (precision := config.trainer.set_float32_matmul_precision) is not None: + torch.set_float32_matmul_precision(precision) + + stack.enter_context( + set_additional_env_vars( + config.trainer.additional_env_vars | additional_nccl_env_vars + ) + ) + + try: + yield + finally: + n_finalizers = 0 + for finalizer in reversed(cls._finalizers): + finalizer() + n_finalizers += 1 + + cls._finalizers.clear() + log.critical( + f"Ran {n_finalizers} finalizers for {cls.__name__} cleanup." + ) + + @classmethod + def _update_kwargs(cls, config: BaseConfig, kwargs_ctor: dict[str, Any]): + kwargs = { + "accelerator": config.trainer.accelerator, + "strategy": config.trainer.strategy, + "devices": config.trainer.devices, + "num_nodes": config.trainer.num_nodes, + "precision": config.trainer.precision, + "logger": config.trainer.logger, + "fast_dev_run": config.trainer.fast_dev_run, + "max_epochs": config.trainer.max_epochs, + "min_epochs": config.trainer.min_epochs, + "max_steps": config.trainer.max_steps, + "min_steps": config.trainer.min_steps, + "max_time": config.trainer.max_time, + "limit_train_batches": config.trainer.limit_train_batches, + "limit_val_batches": config.trainer.limit_val_batches, + "limit_test_batches": config.trainer.limit_test_batches, + "limit_predict_batches": config.trainer.limit_predict_batches, + "overfit_batches": config.trainer.overfit_batches, + "val_check_interval": config.trainer.val_check_interval, + "check_val_every_n_epoch": config.trainer.check_val_every_n_epoch, + "num_sanity_val_steps": config.trainer.num_sanity_val_steps, + "log_every_n_steps": config.trainer.log_every_n_steps, + "enable_checkpointing": config.trainer.enable_checkpointing, + "enable_progress_bar": config.trainer.enable_progress_bar, + "enable_model_summary": config.trainer.enable_model_summary, + "accumulate_grad_batches": config.trainer.accumulate_grad_batches, + "deterministic": config.trainer.deterministic, + "benchmark": config.trainer.benchmark, + "inference_mode": config.trainer.inference_mode, + "use_distributed_sampler": config.trainer.use_distributed_sampler, + "detect_anomaly": config.trainer.detect_anomaly, + "barebones": config.trainer.barebones, + "plugins": config.trainer.plugins, + "sync_batchnorm": config.trainer.sync_batchnorm, + "reload_dataloaders_every_n_epochs": config.trainer.reload_dataloaders_every_n_epochs, + } + # if config.trainer.automatic_gradient_clip: + # kwargs["gradient_clip_val"] = config.trainer.gradient_clip_val + # kwargs["gradient_clip_algorithm"] = config.trainer.gradient_clip_algorithm + if ( + grad_clip_config := config.trainer.optimizer.gradient_clipping + ) is not None and grad_clip_config.enabled: + kwargs["gradient_clip_algorithm"] = grad_clip_config.algorithm + kwargs["gradient_clip_val"] = grad_clip_config.value + + if profiler := config.trainer.profiler: + # If the profiler is an ProfilerConfig instance, then we instantiate it. + if isinstance(profiler, BaseProfilerConfig): + profiler = profiler.construct_profiler() + # Make sure that the profiler is an instance of `Profiler`. + if not isinstance(profiler, Profiler): + raise ValueError(f"{profiler=} is not an instance of `{Profiler}`.") + + # Otherwise, if the profiler is a string (e.g., "simpe", "advanced", "pytorch"), + # then we just pass it through. + kwargs["profiler"] = profiler + + kwargs.update(kwargs_ctor) + + kwargs["plugins"] = [] + if config.trainer.plugins is not None: + kwargs["plugins"].extend(config.trainer.plugins) + if (plugins := kwargs_ctor.get("plugins")) is not None: + plugins = [plugins] if not isinstance(plugins, list) else plugins + kwargs["plugins"].extend(plugins) + + if config.trainer.logger is False: + log.critical(f"Disabling logger because {config.trainer.logger=}.") + kwargs["logger"] = False + elif kwargs.get("logger") is False: + log.critical(f"Disabling logger because {kwargs.get('logger')=}.") + + if ( + existing_loggers := kwargs.get("logger") + ) is not False and config.trainer.auto_set_loggers: + if int(config.trainer.fast_dev_run) > 0: + log.critical("Disabling loggers because fast_dev_run is enabled.") + else: + loggers = loggers_from_config(config) + if existing_loggers is not None and not isinstance( + existing_loggers, bool + ): + if not isinstance(existing_loggers, list): + existing_loggers = [existing_loggers] + loggers.extend(existing_loggers) + + kwargs["logger"] = loggers + + if kwargs.get("num_nodes") == "auto": + # when num_nodes is auto, we need to detect the number of nodes + # when on slurm, this would be the number of SLURM nodes allocated + if SLURMEnvironment.detect(): + from submitit import JobEnvironment + + job = JobEnvironment() + if not job.activated(): + raise ValueError( + "SLURMEnvironment detected through PL but not submitit. This is a bug." + ) + + kwargs["num_nodes"] = job.num_nodes + log.critical( + f"Setting num_nodes to {job.num_nodes} (detected through submitit)." + ) + # otherweise, we assume 1 node + else: + kwargs["num_nodes"] = 1 + log.critical("Setting num_nodes to 1 (no SLURM detected).") + + if config.trainer.default_root_dir: + kwargs["default_root_dir"] = str(config.trainer.default_root_dir) + kwargs.update(config.trainer.additional_trainer_kwargs) + + # Set the callbacks + callbacks = kwargs.get("callbacks", []) + if not isinstance(callbacks, list): + callbacks = [callbacks] + callbacks.extend(cls.ll_default_callbacks(config)) + kwargs["callbacks"] = callbacks + + return kwargs + + @override + @copy_method_with_param( + LightningTrainer.__init__, + param_type=BaseConfig, + return_type=NoneType, + ) + def __init__(self, config: BaseConfig, *args, **kwargs): + self._ll_config = config + kwargs = self._update_kwargs(config, kwargs) + log.critical(f"LightningTrainer.__init__ with {args=} and {kwargs=}.") + super().__init__(*args, **kwargs) + + if config.trainer.enable_logger_validation: + for logger in self.loggers: + validate_logger(logger, config.id) + + if config.trainer.checkpoint_last_by_default: + self._patch_checkpoint_last_by_default() + if config.trainer.auto_add_trainer_finalizer: + type(self)._finalizers.append(self.finalize) + + # Print out the log dir, so that we can easily find it in the logs. + if log_dir := self.log_dir: + log_dir = str(Path(log_dir).resolve()) + log.critical(f"LightningTrainer log directory: {self.log_dir}.") + + def _patch_checkpoint_last_by_default(self): + """ + Patch the default ModelCheckpoint callback to save the last checkpoint by default. + """ + enable_checkpointing = ( + True + if self._ll_config.trainer.enable_checkpointing is None + else self._ll_config.trainer.enable_checkpointing + ) + if not enable_checkpointing: + return + + if not (callbacks := getattr(self, "callbacks", None)) or not isinstance( + callbacks, abc.Iterable + ): + return + + if ( + model_ckpt := next( + (c for c in callbacks if isinstance(c, ModelCheckpoint)), None + ) + ) is None: + return + + log.critical(f"Setting {model_ckpt.__class__.__name__}.save_last=True.") + model_ckpt.save_last = True + # hacky: call the `__validate_init_configuration` method to ensure that the `save_last` parameter is valid. + # model_ckpt.__validate_init_configuration() <- this doesn't work because it's a private method + if ( + validate_init_configuration := getattr( + model_ckpt, + f"_{model_ckpt.__class__.__name__}__validate_init_configuration", + None, + ) + ) is not None and callable(validate_init_configuration): + validate_init_configuration() + else: + log.warning( + f"Failed to find {model_ckpt.__class__.__name__}.__validate_init_configuration. " + "This means that we cannot validate the `save_last` parameter for ModelCheckpoint." + ) + + @override + def _run( + self, model: LightningModule, ckpt_path: _PATH | None = None + ) -> _EVALUATE_OUTPUT | _PREDICT_OUTPUT | None: + """ + Lightning doesn't support gradient clipping with manual optimization. + We patch the `Trainer._run` method to throw if gradient clipping is enabled + and `model.automatic_optimization` is False. + """ + if not model.automatic_optimization and ( + self.gradient_clip_val is not None + or self.gradient_clip_algorithm is not None + ): + raise ValueError( + "Gradient clipping is not supported with manual optimization. " + f"Please set {model.__class__.__name__}.automatic_optimization to True " + "or disable automatic gradient clipping. " + "If you want to use gradient clipping with manual optimization, you can " + "set `config.trainer.automatic_gradient_clip=False` and " + "use the values in `config.trainer.gradient_clip_val` and `config.trainer.gradient_clip_algorithm`." + ) + + return super()._run(model, ckpt_path) diff --git a/src/jmp/lightning/util/distributed/debug.py b/src/jmp/lightning/util/distributed/debug.py new file mode 100644 index 0000000..82eca31 --- /dev/null +++ b/src/jmp/lightning/util/distributed/debug.py @@ -0,0 +1,69 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from datetime import timedelta +from functools import wraps +from logging import getLogger + +import torch.distributed + +logger = getLogger(__name__) + + +def _wrap(name: str, log: bool = False): + fn = getattr(torch.distributed, name) + + @wraps(fn) + def wrapper(*args, **kwargs): + nonlocal log + + if log: + logger.critical( + f"Calling {fn.__name__} with args: {args}, kwargs: {kwargs}" + ) + _ = torch.distributed.og_barrier() + + return fn(*args, **kwargs) + + setattr(torch.distributed, name, wrapper) + + +def debug_distributed( + timeout: int | float | timedelta = 120.0, + log: bool = False, +): + if not isinstance(timeout, timedelta): + timeout = timedelta(seconds=timeout) + logger.critical(f"Patching torch.distributed for debug. Timeout: {timeout}. ") + + @wraps(torch.distributed.barrier) + def barrier_fn(group=torch.distributed.GroupMember.WORLD): + logger.critical("Calling torch.distributed.barrier.") + return torch.distributed.monitored_barrier( + group, timeout=timeout, wait_all_ranks=True + ) + + torch.distributed.og_barrier = torch.distributed.barrier + torch.distributed.barrier = barrier_fn + + fn_names = [ + # "send", + # "recv", + "broadcast", + "all_reduce", + "reduce", + "all_gather", + "all_gather_object", + "gather", + "scatter", + "reduce_scatter", + "all_to_all", + ] + for name in fn_names: + _wrap(name, log=log) + logger.critical(f"Wrapped torch.distributed.{name}.") diff --git a/src/jmp/lightning/util/environment.py b/src/jmp/lightning/util/environment.py new file mode 100644 index 0000000..66bf0cb --- /dev/null +++ b/src/jmp/lightning/util/environment.py @@ -0,0 +1,101 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +from contextlib import contextmanager +from logging import getLogger + +log = getLogger(__name__) + + +@contextmanager +def remove_slurm_environment_variables(): + """ + SLURM_CPU_BIND_* environment variables are set by SLURM in the current environment. + We need to remove all of these environment variables during the codepath in which we create the new SLURM runs, so that the new SLURM runs do not inherit the environment variables from the current environment. + To make things easier, we will patch the environment to remove all "SLURM_" environment variables. + Otherwise, the runs will faill with an error like shown below: + srun: error: CPU binding outside of job step allocation, allocated CPUs are: 0x01F000000001F0000000. + srun: error: Task launch for StepId=5216715.0 failed on node learnfair0537: Unable to satisfy cpu bind request + srun: error: Application launch failed: Unable to satisfy cpu bind request + srun: Job step aborted + + See https://www.mail-archive.com/slurm-users@lists.schedmd.com/msg09157.html for more details. + """ + + removed_env_vars = {} + for key in list(os.environ.keys()): + if not key.startswith("SLURM_"): + continue + removed_env_vars[key] = os.environ.pop(key) + + log.debug( + f"Removed environment variables before launching new SLURM job: {list(removed_env_vars.keys())}" + ) + try: + yield + finally: + os.environ.update(removed_env_vars) + log.debug( + f"Restored environment variables after launching new SLURM job: {list(removed_env_vars.keys())}" + ) + + +@contextmanager +def remove_wandb_environment_variables(): + """ + Similar to above, but removes all "WANDB_" environment variables. + """ + + removed_env_vars = {} + for key in list(os.environ.keys()): + if not key.startswith("WANDB_"): + continue + removed_env_vars[key] = os.environ.pop(key) + + log.debug( + f"Removed environment variables before launching new SLURM job: {list(removed_env_vars.keys())}" + ) + try: + yield + finally: + os.environ.update(removed_env_vars) + log.debug( + f"Restored environment variables after launching new SLURM job: {list(removed_env_vars.keys())}" + ) + + +@contextmanager +def set_additional_env_vars(additional_env_vars: dict[str, str] | None = None): + """ + Set additional environment variables for the run. + Newly set environment variables will be removed after the run is finished. + Existing environment variables will be restored to their original values after the run is finished. + """ + if additional_env_vars is None: + additional_env_vars = {} + + removed_env_vars = {} + for key, value in additional_env_vars.items(): + removed_env_vars[key] = os.environ.pop(key, None) + os.environ[key] = value + + log.debug( + f"Set additional environment variables for the run: {list(additional_env_vars.keys())}" + ) + try: + yield + finally: + for key in additional_env_vars.keys(): + if removed_env_vars[key] is None: + del os.environ[key] + else: + os.environ[key] = removed_env_vars[key] + log.debug( + f"Restored environment variables after launching new SLURM job: {list(additional_env_vars.keys())}" + ) diff --git a/src/jmp/lightning/util/log_batch_info.py b/src/jmp/lightning/util/log_batch_info.py new file mode 100644 index 0000000..25d41ff --- /dev/null +++ b/src/jmp/lightning/util/log_batch_info.py @@ -0,0 +1,57 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import wraps +from logging import getLogger + +from lightning.pytorch import LightningModule + +from ..exception import SkipBatch, TrainingError + +log = getLogger(__name__) + + +def _wrap_fn(module: LightningModule, fn_name: str): + old_step = getattr(module, fn_name).__func__ + + @wraps(old_step) + def new_step(module: LightningModule, batch, batch_idx, *args, **kwargs): + try: + return old_step(module, batch, batch_idx, *args, **kwargs) + except BaseException as e: + if isinstance(e, SkipBatch): + # we don't need to handle this case + raise e + + # we need to re-raise the exception with more information + raise TrainingError( + e, + batch_idx=batch_idx, + batch=batch, + epoch=module.current_epoch, + global_step=module.global_step, + training_fn=fn_name, + ) from e + + setattr(module, fn_name, new_step.__get__(module)) + log.info(f"Wrapped {fn_name} for log_batch_info") + + +def wrap_lightning_module(module: LightningModule): + log.info( + "Wrapping training_step/validation_step/test_step/predict_step for log_batch_info" + ) + + _wrap_fn(module, "training_step") + _wrap_fn(module, "validation_step") + _wrap_fn(module, "test_step") + _wrap_fn(module, "predict_step") + + log.info( + "Wrapped training_step/validation_step/test_step/predict_step for log_batch_info" + ) diff --git a/src/jmp/lightning/util/notebook/yaml.py b/src/jmp/lightning/util/notebook/yaml.py new file mode 100644 index 0000000..b735c3a --- /dev/null +++ b/src/jmp/lightning/util/notebook/yaml.py @@ -0,0 +1,66 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from IPython.display import HTML, Javascript, display + + +class YamlEditor: + def __init__(self, yaml_str): + self.yaml_str = yaml_str + + def _ipython_display_(self): + # Set up CodeMirror editor options + element_id = f"yaml-editor-{id(self)}" + + # Generate HTML for the CodeMirror editor + # + html = f""" + + + + + """ + _ = display(HTML(html)) + + _ = display( + Javascript( + """ + require.config({ + packages: [{ + name: "codemirror", + location: "https://cdnjs.cloudflare.com/ajax/libs/codemirror/5.58.3", + main: "codemirror" + }], + map: { + "*": { + "codemirror/lib/codemirror": "codemirror", + } + } + }); + + require([ + "codemirror", + "codemirror/mode/yaml/yaml", // for yaml mode + "codemirror/addon/fold/foldcode", // for code folding + "codemirror/addon/fold/foldgutter", // for code folding + "codemirror/addon/fold/brace-fold", // for code folding + "codemirror/addon/fold/comment-fold", // for code folding + "codemirror/addon/fold/indent-fold", // for code folding + ], function(CodeMirror) { + var editor = CodeMirror.fromTextArea(document.getElementById("%(element_id)s"), { + mode: "yaml", + lineNumbers: true, + readOnly: true, + foldGutter: true, + gutters: ["CodeMirror-linenumbers", "CodeMirror-foldgutter"], + }); + }); + """ + % {"element_id": element_id} + ) + ) diff --git a/src/jmp/lightning/util/pretty_print.py b/src/jmp/lightning/util/pretty_print.py new file mode 100644 index 0000000..a9a6f69 --- /dev/null +++ b/src/jmp/lightning/util/pretty_print.py @@ -0,0 +1,41 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch + + +class PrettyPrintMixin: + def __repr__(self): + base_cls_name = self.__class__.__name__ + prop_list = [] + properties = {} + properties.update( + { + k: v + for k in dir(self) + if (attr := getattr(self.__class__, k, None)) is not None + and isinstance(attr, property) + and (v := getattr(self, k, None)) is not None + } + ) + properties.update(self.__dict__) + for k, v in properties.items(): + if isinstance(v, (int, float, str)): + prop_list.append(f"{k}={v}") + elif isinstance(v, (np.ndarray, torch.Tensor)): + numel = v.numel() if isinstance(v, torch.Tensor) else np.prod(v.shape) + if numel == 1: + prop_list.append(f"{k}={v.item()}") + else: + shape = list(v.shape) + prop_list.append(f"{k}={shape}") + else: + prop_list.append(f"{k}={type(v)}") + prop_repr = ", \n".join(f"\t{p}" for p in prop_list) + return f"{base_cls_name}(\n{prop_repr}\n)" diff --git a/src/jmp/lightning/util/seed.py b/src/jmp/lightning/util/seed.py new file mode 100644 index 0000000..4b50e55 --- /dev/null +++ b/src/jmp/lightning/util/seed.py @@ -0,0 +1,38 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from contextlib import contextmanager +from logging import getLogger + +import lightning_fabric.utilities.seed as LS + +log = getLogger(__name__) + + +def seed_everything(seed: int, *, workers: bool = False): + seed = LS.seed_everything(seed, workers=workers) + log.critical(f"Set global seed to {seed}.") + return seed + + +def reset_seed(): + LS.reset_seed() + log.critical("Reset global seed.") + + +@contextmanager +def seed_context(seed: int | None, *, workers: bool = False): + if seed is None: + seed = LS._select_seed_randomly() + log.warning(f"No seed provided, using random seed {seed}.") + + try: + seed = seed_everything(seed, workers=workers) + yield + finally: + reset_seed() diff --git a/src/jmp/lightning/util/singleton.py b/src/jmp/lightning/util/singleton.py new file mode 100644 index 0000000..fb1a15a --- /dev/null +++ b/src/jmp/lightning/util/singleton.py @@ -0,0 +1,94 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from typing import Any + +from typing_extensions import Self, TypeVar, override + +log = getLogger(__name__) + + +class Singleton: + singleton_key = "_singleton_instance" + + @classmethod + def get(cls) -> Self | None: + return getattr(cls, cls.singleton_key, None) + + @classmethod + def set(cls, instance: Self) -> None: + if cls.get() is not None: + log.warning(f"{cls.__qualname__} instance is already set") + + setattr(cls, cls.singleton_key, instance) + + @classmethod + def reset(cls) -> None: + if cls.get() is not None: + delattr(cls, cls.singleton_key) + + @classmethod + def register(cls, instance: Self) -> None: + cls.set(instance) + + @classmethod + def instance(cls) -> Self: + instance = cls.get() + if instance is None: + raise RuntimeError(f"{cls.__qualname__} instance is not set") + + return instance + + @override + def __init_subclass__(cls, *args, **kwargs) -> None: + super().__init_subclass__(*args, **kwargs) + + cls.reset() + + +T = TypeVar("T", infer_variance=True) + + +class Registry: + _registry: dict[type, Any] = {} + + @staticmethod + def register(cls_: type[T], instance: T): + if not isinstance(instance, cls_): + raise ValueError(f"{instance} is not an instance of {cls_.__qualname__}") + + if cls_ in Registry._registry: + raise ValueError(f"{cls_.__qualname__} is already registered") + + Registry._registry[cls_] = instance + + @staticmethod + def try_get(cls_: type[T]) -> T | None: + return Registry._registry.get(cls_) + + @staticmethod + def get(cls_: type[T]) -> T: + instance = Registry.try_get(cls_) + if instance is None: + raise ValueError(f"{cls_.__qualname__} is not registered") + + return instance + + @staticmethod + def instance(cls_: type[T]) -> T: + return Registry.get(cls_) + + @staticmethod + def reset(cls_: type[T]): + if cls_ in Registry._registry: + del Registry._registry[cls_] + + @staticmethod + def reset_all(): + Registry._registry.clear() diff --git a/src/jmp/lightning/util/skip_batch.py b/src/jmp/lightning/util/skip_batch.py new file mode 100644 index 0000000..e33e904 --- /dev/null +++ b/src/jmp/lightning/util/skip_batch.py @@ -0,0 +1,66 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import wraps +from logging import getLogger +from typing import TYPE_CHECKING + +from typing_extensions import TypeVar + +from ..exception import SkipBatch +from ..model.config import BaseConfig + +if TYPE_CHECKING: + from ..model.base import LightningModuleBase + + +log = getLogger(__name__) + +THparams = TypeVar("THparams", bound=BaseConfig, infer_variance=True) + + +def _wrap_fn( + module: "LightningModuleBase[THparams]", + fn_name: str, + is_training_step: bool = False, +): + old_step = getattr(module, fn_name).__func__ + + @wraps(old_step) + def new_step( + module: "LightningModuleBase[THparams]", batch, batch_idx, *args, **kwargs + ): + try: + return old_step(module, batch, batch_idx, *args, **kwargs) + except SkipBatch as e: + log.info( + f"[{fn_name}] @ [step={module.global_step}, batch={batch_idx}]: Skipping batch due to SkipBatch exception: {e}" + ) + + if is_training_step: + return module.skip_batch_training_step( + batch, batch_idx, *args, **kwargs + ) + + setattr(module, fn_name, new_step.__get__(module)) + log.info(f"Wrapped {fn_name} for skip_batch_exception") + + +def wrap_lightning_module(module: "LightningModuleBase[THparams]"): + log.info( + "Wrapping training_step/validation_step/test_step/predict_step for skip_batch_exception" + ) + + _wrap_fn(module, "training_step", is_training_step=True) + _wrap_fn(module, "validation_step", is_training_step=False) + _wrap_fn(module, "test_step", is_training_step=False) + _wrap_fn(module, "predict_step", is_training_step=False) + + log.info( + "Wrapped training_step/validation_step/test_step/predict_step for skip_batch_exception" + ) diff --git a/src/jmp/lightning/util/slurm.py b/src/jmp/lightning/util/slurm.py new file mode 100644 index 0000000..c4e1f45 --- /dev/null +++ b/src/jmp/lightning/util/slurm.py @@ -0,0 +1,109 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import getpass +from datetime import timedelta +from logging import getLogger +from pathlib import Path + +from submitit import AutoExecutor + +from .snapshot import snapshot_modules + +log = getLogger(__name__) + + +def create_executor( + *, + tasks_per_node: int, + cpus_per_task: int, + gpus_per_task: int, + nodes: int, + partition: str, + timeout: timedelta = timedelta(hours=72), + memory: int = 480, + email: str | None = None, + constraints: list[str] | None = None, + volta16gb: bool | None = None, + volta32gb: bool | None = None, + slurm_additional_parameters: dict[str, str] | None = None, + slurm_setup: list[str] | None = None, + snapshot: bool | Path, + snapshot_base: Path | None = None, + env: dict[str, str] | None = None, + job_name: str = "jmplightning", + snapshot_env_name: str = "LL_SNAPSHOT", +): + if volta16gb and volta32gb: + raise ValueError("Cannot have both volta16gb and volta32gb") + elif volta16gb is None and volta32gb is None: + volta16gb = False + volta32gb = True + + if volta16gb is None: + volta16gb = False + if volta32gb is None: + volta32gb = False + + if snapshot_base is None: + current_user = getpass.getuser() + snapshot_base = Path(f"/checkpoint/{current_user}/ll_snapshots/") + + if snapshot is True: + snapshot = snapshot_modules(snapshot_base, ["jmp", "submitit"]).absolute() + + base_path = Path(".") / "slurm_logs" + base_path.mkdir(exist_ok=True, parents=True) + + additional_parameters = {} + if not constraints: + constraints = [] + if email: + additional_parameters.update({"mail_user": email, "mail_type": "FAIL"}) + if volta16gb: + # additional_parameters.update({"constraint": "volta16gb"}) + constraints.append("volta16gb") + if volta32gb: + # additional_parameters.update({"constraint": "volta32gb"}) + constraints.append("volta32gb") + if slurm_additional_parameters: + additional_parameters.update(slurm_additional_parameters) + + # add constraints from slurm_additional_parameters + if (constraint := additional_parameters.pop("constraint", None)) is not None: + constraints.append(constraint) + + # remove duplicates + constraints = list(set(constraints)) + if constraints: + additional_parameters.update({"constraint": ",".join(constraints)}) + + setup = [] + if env: + setup.extend(f"export {k}={v}" for k, v in env.items()) + if slurm_setup: + setup.extend(slurm_setup) + if snapshot: + snapshot_str = str(snapshot.resolve().absolute()) + setup.append(f"export {snapshot_env_name}={snapshot_str}") + setup.append(f"export PYTHONPATH={snapshot_str}:$PYTHONPATH") + + executor = AutoExecutor(folder=base_path / "%j") + executor.update_parameters( + name=job_name, + mem_gb=memory, + timeout_min=int(timeout.total_seconds() / 60), + cpus_per_task=cpus_per_task, + tasks_per_node=tasks_per_node, + nodes=nodes, + slurm_gpus_per_task=gpus_per_task, + slurm_partition=partition, + slurm_additional_parameters=additional_parameters, + slurm_setup=setup, + ) + return executor diff --git a/src/jmp/lightning/util/snapshot.py b/src/jmp/lightning/util/snapshot.py new file mode 100644 index 0000000..036fe8b --- /dev/null +++ b/src/jmp/lightning/util/snapshot.py @@ -0,0 +1,109 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import importlib.util +import subprocess +import uuid +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime +from logging import getLogger +from pathlib import Path +from typing import Sequence + +log = getLogger(__name__) + + +@dataclass(kw_only=True) +class SnapshotInformation: + snapshot_dir: Path + moved_modules: dict[str, list[tuple[Path, Path]]] + + +def _copy(source: Path, location: Path): + ignored_files = ( + subprocess.check_output( + [ + "git", + "-C", + str(source), + "ls-files", + "--exclude-standard", + "-oi", + "--directory", + ] + ) + .decode("utf-8") + .splitlines() + ) + + # run rsync with .git folder and `ignored_files` excluded + _ = subprocess.run( + [ + "rsync", + "-a", + "--exclude", + ".git", + *(f"--exclude={file}" for file in ignored_files), + str(source), + str(location), + ], + check=True, + ) + + +def snapshot_modules( + base: str | Path, + modules: Sequence[str], + *, + id: str | None = None, + add_date_to_dir: bool = True, + error_on_existing: bool = True, +): + if not id: + id = str(uuid.uuid4()) + + snapshot_dir = Path(base) + if add_date_to_dir: + snapshot_dir = snapshot_dir / datetime.now().strftime("%Y-%m-%d") + snapshot_dir = snapshot_dir / id + snapshot_dir.mkdir(parents=True, exist_ok=not error_on_existing) + + log.critical(f"Snapshotting to {snapshot_dir}") + + moved_modules = defaultdict[str, list[tuple[Path, Path]]](list) + for module in modules: + spec = importlib.util.find_spec(module) + if spec is None: + log.warning(f"Module {module} not found") + continue + + assert ( + spec.submodule_search_locations + and len(spec.submodule_search_locations) == 1 + ), f"Could not find module {module} in a single location." + location = Path(spec.submodule_search_locations[0]) + assert ( + location.is_dir() + ), f"Module {module} has a non-directory location {location}" + + (*parent_modules, module_name) = module.split(".") + + destination = snapshot_dir + for part in parent_modules: + destination = destination / part + destination.mkdir(parents=True, exist_ok=True) + (destination / "__init__.py").touch(exist_ok=True) + + _copy(location, destination) + + destination = destination / module_name + log.info(f"Moved {location} to {destination} for {module=}") + moved_modules[module].append((location, destination)) + + return snapshot_dir diff --git a/src/jmp/lightning/util/typed/__init__.py b/src/jmp/lightning/util/typed/__init__.py new file mode 100644 index 0000000..c352ee1 --- /dev/null +++ b/src/jmp/lightning/util/typed/__init__.py @@ -0,0 +1,12 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .module_dict import TypedModuleDict +from .module_list import TypedModuleList + +__all__ = ["TypedModuleDict", "TypedModuleList"] diff --git a/src/jmp/lightning/util/typed/module_dict.py b/src/jmp/lightning/util/typed/module_dict.py new file mode 100644 index 0000000..7d938fd --- /dev/null +++ b/src/jmp/lightning/util/typed/module_dict.py @@ -0,0 +1,68 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Generic, Iterable, Mapping + +import torch.nn as nn +from typing_extensions import TypeVar + +TModule = TypeVar("TModule", bound=nn.Module, infer_variance=True) + + +class TypedModuleDict(nn.Module, Generic[TModule]): + def __init__( + self, + modules: Mapping[str, TModule] | None = None, + key_prefix: str = "_typed_moduledict_", + # we use a key prefix to avoid attribute name collisions + # (which is a common issue in nn.ModuleDict as it uses `__setattr__` to set the modules) + ): + super().__init__() + + self.key_prefix = key_prefix + self._module_dict = nn.ModuleDict( + {self._with_prefix(k): v for k, v in modules.items()} + ) + + def _with_prefix(self, key: str) -> str: + return f"{self.key_prefix}{key}" + + def _remove_prefix(self, key: str) -> str: + assert key.startswith( + self.key_prefix + ), f"{key} does not start with {self.key_prefix}" + return key[len(self.key_prefix) :] + + def __setitem__(self, key: str, module: TModule) -> None: + key = self._with_prefix(key) + return self._module_dict.__setitem__(key, module) + + def __getitem__(self, key: str) -> TModule: + key = self._with_prefix(key) + return self._module_dict.__getitem__(key) # type: ignore + + def update(self, modules: Mapping[str, TModule]) -> None: + return self._module_dict.update( + {self._with_prefix(k): v for k, v in modules.items()} + ) + + def get(self, key: str) -> TModule | None: + key = self._with_prefix(key) + return self._module_dict._modules.get(key) + + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ModuleDict keys.""" + return [self._remove_prefix(k) for k in self._module_dict.keys()] + + def items(self) -> Iterable[tuple[str, TModule]]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return [(self._remove_prefix(k), v) for k, v in self._module_dict.items()] + + def values(self) -> Iterable[TModule]: + r"""Return an iterable of the ModuleDict values.""" + return self._module_dict.values() diff --git a/src/jmp/lightning/util/typed/module_list.py b/src/jmp/lightning/util/typed/module_list.py new file mode 100644 index 0000000..5871145 --- /dev/null +++ b/src/jmp/lightning/util/typed/module_list.py @@ -0,0 +1,57 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Generic, Iterable, Iterator, Optional, TypeVar, overload + +import torch.nn as nn +from typing_extensions import override + +TModule = TypeVar("TModule") + + +class TypedModuleList(nn.ModuleList, Generic[TModule]): + def __init__(self, modules: Optional[Iterable[TModule]] = None) -> None: + super().__init__(modules) + + @overload + def __getitem__(self, idx: int) -> TModule: ... + + @overload + def __getitem__(self, idx: slice) -> "TypedModuleList[TModule]": ... + + @override + def __getitem__(self, idx: int | slice) -> TModule | "TypedModuleList[TModule]": + return super().__getitem__(idx) # type: ignore + + @override + def __setitem__(self, idx: int, module: TModule) -> None: + return super().__setitem__(idx, module) + + @override + def __iter__(self) -> Iterator[TModule]: + return super().__iter__() # type: ignore + + @override + def __iadd__(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]": + return super().__iadd__(modules) # type: ignore + + @override + def __add__(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]": + return super().__add__(modules) # type: ignore + + @override + def insert(self, idx: int, module: TModule) -> None: + return super().insert(idx, module) # type: ignore + + @override + def append(self, module: TModule) -> "TypedModuleList[TModule]": + return super().append(module) # type: ignore + + @override + def extend(self, modules: Iterable[TModule]) -> "TypedModuleList[TModule]": + return super().extend(modules) # type: ignore diff --git a/src/jmp/lightning/util/typing_utils.py b/src/jmp/lightning/util/typing_utils.py new file mode 100644 index 0000000..ea7978d --- /dev/null +++ b/src/jmp/lightning/util/typing_utils.py @@ -0,0 +1,67 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import TYPE_CHECKING, Any, Callable, Concatenate, Type, cast + +from typing_extensions import ParamSpec, TypeVar + +P = ParamSpec("P") +TSelf = TypeVar("TSelf", infer_variance=True) +TParam = TypeVar("TParam", infer_variance=True) +TReturn = TypeVar("TReturn", infer_variance=True) + + +def copy_args( + kwargs_call: Callable[P, Any], + *, + return_type: Type[TReturn], +) -> Callable[[Callable[..., TReturn]], Callable[P, TReturn]]: + """ + Copies the type annotations from one function to another. + """ + + def return_func(func: Callable[..., TReturn]): + return cast(Callable[P, TReturn], func) + + return return_func + + +def copy_method_with_param( + kwargs_call: Callable[Concatenate[TSelf, P], Any], + *, + param_type: Type[TParam], + return_type: Type[TReturn], +) -> Callable[ + [Callable[..., TReturn]], Callable[Concatenate[TSelf, TParam, P], TReturn] +]: + """ + Copies the type annotations from one method to another, + but adds a new parameter to the beginning. + """ + + def return_func(func: Callable[..., TReturn]): + return cast(Callable[Concatenate[TSelf, TParam, P], TReturn], func) + + return return_func + + +TBase = TypeVar("TBase") + + +def mixin_base_type(base_class: Type[TBase]) -> Type[TBase]: + """ + Useful function to make mixins with baseclass typehint + + ``` + class ReadonlyMixin(mixin_base_type(BaseAdmin))): + ... + ``` + """ + if TYPE_CHECKING: + return base_class + return object diff --git a/src/jmp/lightning/util/util.py b/src/jmp/lightning/util/util.py new file mode 100644 index 0000000..44a79c9 --- /dev/null +++ b/src/jmp/lightning/util/util.py @@ -0,0 +1,54 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import importlib +from typing import Callable, Dict, List + +import torch + + +def convert_list_of_dicts( + data_list: List[Dict[str, torch.Tensor]], +) -> Dict[str, List[torch.Tensor]]: + return { + k: [d[k] if torch.is_tensor(d[k]) else torch.tensor(d[k]) for d in data_list] + for k in data_list[0].keys() + } + + +def compose(*transforms: Callable): + def composed(x): + for transform in transforms: + x = transform(x) + return x + + return composed + + +def get_absolute_mapping(name: str): + # in this case, the `name` should be the fully qualified name of the class + # e.g., `ocpmodels.tasks.base_task.BaseTask` + # we can use importlib to get the module (e.g., `ocpmodels.tasks.base_task`) + # and then import the class (e.g., `BaseTask`) + + module_name = ".".join(name.split(".")[:-1]) + class_name = name.split(".")[-1] + + try: + module = importlib.import_module(module_name) + except ModuleNotFoundError as e: + raise RuntimeError( + f"Could not import module {module_name=} for class {name=}" + ) from e + + try: + return getattr(module, class_name) + except AttributeError as e: + raise RuntimeError( + f"Could not import class {class_name=} from module {module_name=}" + ) from e diff --git a/src/jmp/models/gemnet/__init__.py b/src/jmp/models/gemnet/__init__.py new file mode 100644 index 0000000..7e1665b --- /dev/null +++ b/src/jmp/models/gemnet/__init__.py @@ -0,0 +1,7 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" diff --git a/src/jmp/models/gemnet/backbone.py b/src/jmp/models/gemnet/backbone.py new file mode 100644 index 0000000..078a4e9 --- /dev/null +++ b/src/jmp/models/gemnet/backbone.py @@ -0,0 +1,776 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import TypedDict + +import torch +import torch.nn as nn +from jmp.lightning.util.typed import TypedModuleList +from torch_geometric.data.data import BaseData +from torch_scatter import segment_coo + +from ...modules.scaling.compat import load_scales_compat +from ...utils.goc_graph import graphs_from_batch +from .bases import Bases, BasesOutput +from .config import BackboneConfig, BasesConfig +from .interaction_indices import get_mixed_triplets, get_quadruplets, get_triplets +from .layers.atom_update_block import OutputBlock +from .layers.base_layers import Dense, ResidualLayer +from .layers.force_scaler import ForceScaler +from .layers.interaction_block import InteractionBlock +from .utils import ( + get_angle, + get_edge_id, + get_inner_idx, + inner_product_clamped, + repeat_blocks, +) + + +class GOCBackboneOutput(TypedDict): + idx_s: torch.Tensor + idx_t: torch.Tensor + V_st: torch.Tensor + D_st: torch.Tensor + + energy: torch.Tensor + forces: torch.Tensor + + +class FinalMLP(nn.Module): + def __init__( + self, + *, + emb_size: int, + num_blocks: int, + num_global_out_layers: int, + activation: str | None = None, + dropout: float | None = None, + ): + super().__init__() + + out_mlp = [ + Dense( + emb_size * (num_blocks + 1), + emb_size, + activation=activation, + dropout=dropout, + ) + ] + out_mlp += [ + ResidualLayer(emb_size, activation=activation, dropout=dropout) + for _ in range(num_global_out_layers) + ] + self.out_mlp = nn.Sequential(*out_mlp) + + def forward( + self, + x: torch.Tensor, + *, + data: BaseData, + edge_index: torch.Tensor, + ) -> torch.Tensor: + return self.out_mlp(x) + + +class GemNetOCBackbone(nn.Module): + """ + Arguments + --------- + num_atoms (int): Unused argument + bond_feat_dim (int): Unused argument + num_targets: int + Number of prediction targets. + + num_spherical: int + Controls maximum frequency. + num_radial: int + Controls maximum frequency. + num_blocks: int + Number of building blocks to be stacked. + + emb_size_atom: int + Embedding size of the atoms. + emb_size_edge: int + Embedding size of the edges. + emb_size_trip_in: int + (Down-projected) embedding size of the quadruplet edge embeddings + before the bilinear layer. + emb_size_trip_out: int + (Down-projected) embedding size of the quadruplet edge embeddings + after the bilinear layer. + emb_size_quad_in: int + (Down-projected) embedding size of the quadruplet edge embeddings + before the bilinear layer. + emb_size_quad_out: int + (Down-projected) embedding size of the quadruplet edge embeddings + after the bilinear layer. + emb_size_aint_in: int + Embedding size in the atom interaction before the bilinear layer. + emb_size_aint_out: int + Embedding size in the atom interaction after the bilinear layer. + emb_size_rbf: int + Embedding size of the radial basis transformation. + emb_size_cbf: int + Embedding size of the circular basis transformation (one angle). + emb_size_sbf: int + Embedding size of the spherical basis transformation (two angles). + + num_before_skip: int + Number of residual blocks before the first skip connection. + num_after_skip: int + Number of residual blocks after the first skip connection. + num_concat: int + Number of residual blocks after the concatenation. + num_atom: int + Number of residual blocks in the atom embedding blocks. + num_output_afteratom: int + Number of residual blocks in the output blocks + after adding the atom embedding. + num_atom_emb_layers: int + Number of residual blocks for transforming atom embeddings. + num_global_out_layers: int + Number of final residual blocks before the output. + + regress_forces: bool + Whether to predict forces. Default: True + direct_forces: bool + If True predict forces based on aggregation of interatomic directions. + If False predict forces based on negative gradient of energy potential. + use_pbc: bool + Whether to use periodic boundary conditions. + scale_backprop_forces: bool + Whether to scale up the energy and then scales down the forces + to prevent NaNs and infs in backpropagated forces. + + rbf: dict + Name and hyperparameters of the radial basis function. + rbf_spherical: dict + Name and hyperparameters of the radial basis function used as part of the + circular and spherical bases. + Optional. Uses rbf per default. + envelope: dict + Name and hyperparameters of the envelope function. + cbf: dict + Name and hyperparameters of the circular basis function. + sbf: dict + Name and hyperparameters of the spherical basis function. + extensive: bool + Whether the output should be extensive (proportional to the number of atoms) + forces_coupled: bool + If True, enforce that |F_st| = |F_ts|. No effect if direct_forces is False. + output_init: str + Initialization method for the final dense layer. + activation: str + Name of the activation function. + scale_file: str + Path to the pytorch file containing the scaling factors. + + quad_interaction: bool + Whether to use quadruplet interactions (with dihedral angles) + atom_edge_interaction: bool + Whether to use atom-to-edge interactions + edge_atom_interaction: bool + Whether to use edge-to-atom interactions + atom_interaction: bool + Whether to use atom-to-atom interactions + + scale_basis: bool + Whether to use a scaling layer in the raw basis function for better + numerical stability. + qint_tags: list + Which atom tags to use quadruplet interactions for. + 0=sub-surface bulk, 1=surface, 2=adsorbate atoms. + """ + + qint_tags: torch.Tensor + + def __init__( + self, + config: BackboneConfig, + *, + num_targets: int, + num_spherical: int, + num_radial: int, + num_blocks: int, + emb_size_atom: int, + emb_size_edge: int, + emb_size_trip_in: int, + emb_size_trip_out: int, + emb_size_quad_in: int, + emb_size_quad_out: int, + emb_size_aint_in: int, + emb_size_aint_out: int, + emb_size_rbf: int, + emb_size_cbf: int, + emb_size_sbf: int, + num_before_skip: int, + num_after_skip: int, + num_concat: int, + num_atom: int, + num_output_afteratom: int, + num_atom_emb_layers: int = 0, + num_global_out_layers: int = 2, + regress_energy: bool = True, + regress_forces: bool = True, + direct_forces: bool = False, + use_pbc: bool = True, + scale_backprop_forces: bool = False, + rbf: dict = {"name": "gaussian"}, + rbf_spherical: dict | None = None, + envelope: dict = {"name": "polynomial", "exponent": 5}, + cbf: dict = {"name": "spherical_harmonics"}, + sbf: dict = {"name": "spherical_harmonics"}, + extensive: bool = True, + forces_coupled: bool = False, + activation: str = "silu", + quad_interaction: bool = False, + atom_edge_interaction: bool = False, + edge_atom_interaction: bool = False, + atom_interaction: bool = False, + scale_basis: bool = False, + qint_tags: list = [0, 1, 2], + num_elements: int = 120, + otf_graph: bool = False, + scale_file: str | None = None, + absolute_rbf_cutoff: float | None = None, + **kwargs, + ): + super().__init__() + + self.shared_parameters: list[tuple[nn.Parameter, int]] = [] + + print("Unrecognized arguments: ", kwargs.keys()) + + self.config = config + + self.num_targets = num_targets + assert num_blocks > 0 + self.num_blocks = num_blocks + self.extensive = extensive + self.num_elements = num_elements + + self.atom_edge_interaction = atom_edge_interaction + self.edge_atom_interaction = edge_atom_interaction + self.atom_interaction = atom_interaction + self.quad_interaction = quad_interaction + self.otf_graph = otf_graph + + self.register_buffer("qint_tags", torch.tensor(qint_tags), persistent=False) + + if not rbf_spherical: + rbf_spherical = rbf + + self.use_pbc = use_pbc + + self.direct_forces = direct_forces + self.forces_coupled = forces_coupled + self.regress_forces = regress_forces + self.regress_energy = regress_energy + self.force_scaler = ForceScaler(enabled=scale_backprop_forces) + + self.bases = Bases(BasesConfig.from_backbone_config(self.config)) + if not self.config.unique_basis_per_layer: + self.shared_parameters.extend(self.bases.shared_parameters) + else: + self.per_layer_bases = TypedModuleList( + [ + Bases(BasesConfig.from_backbone_config(self.config)) + for _ in range(self.num_blocks) + ] + ) + + # Embedding blocks + # self.atom_emb = AtomEmbedding(emb_size_atom, num_elements) + # self.edge_emb = EdgeEmbedding( + # emb_size_atom, num_radial, emb_size_edge, activation=activation + # ) + + # Interaction Blocks + int_blocks = [] + for _ in range(num_blocks): + int_blocks.append( + InteractionBlock( + emb_size_atom=emb_size_atom, + emb_size_edge=emb_size_edge, + emb_size_trip_in=emb_size_trip_in, + emb_size_trip_out=emb_size_trip_out, + emb_size_quad_in=emb_size_quad_in, + emb_size_quad_out=emb_size_quad_out, + emb_size_a2a_in=emb_size_aint_in, + emb_size_a2a_out=emb_size_aint_out, + emb_size_rbf=emb_size_rbf, + emb_size_cbf=emb_size_cbf, + emb_size_sbf=emb_size_sbf, + num_before_skip=num_before_skip, + num_after_skip=num_after_skip, + num_concat=num_concat, + num_atom=num_atom, + num_atom_emb_layers=num_atom_emb_layers, + quad_interaction=quad_interaction, + atom_edge_interaction=atom_edge_interaction, + edge_atom_interaction=edge_atom_interaction, + atom_interaction=atom_interaction, + activation=activation, + dropout=self.config.dropout, + ) + ) + self.int_blocks = nn.ModuleList(int_blocks) + out_blocks = [] + for _ in range(num_blocks + 1): + out_blocks.append( + OutputBlock( + emb_size_atom=emb_size_atom, + emb_size_edge=emb_size_edge, + emb_size_rbf=emb_size_rbf, + nHidden=num_atom, + nHidden_afteratom=num_output_afteratom, + activation=activation, + direct_forces=direct_forces, + edge_dropout=self.config.edge_dropout, + dropout=self.config.dropout, + ) + ) + self.out_blocks = nn.ModuleList(out_blocks) + + # out_mlp_E = [ + # Dense( + # emb_size_atom * (num_blocks + 1), + # emb_size_atom, + # activation=activation, + # ) + # ] + # out_mlp_E += [ + # ResidualLayer( + # emb_size_atom, + # activation=activation, + # ) + # for _ in range(num_global_out_layers) + # ] + # self.out_mlp_E = nn.Sequential(*out_mlp_E) + self.out_mlp_E = FinalMLP( + emb_size=emb_size_atom, + num_blocks=num_blocks, + num_global_out_layers=num_global_out_layers, + activation=activation, + ) + if direct_forces: + # out_mlp_F = [ + # Dense( + # emb_size_edge * (num_blocks + 1), + # emb_size_edge, + # activation=activation, + # ) + # ] + # out_mlp_F += [ + # ResidualLayer( + # emb_size_edge, + # activation=activation, + # ) + # for _ in range(num_global_out_layers) + # ] + # self.out_mlp_F = nn.Sequential(*out_mlp_F) + self.out_mlp_F = FinalMLP( + emb_size=emb_size_edge, + num_blocks=num_blocks, + num_global_out_layers=num_global_out_layers, + activation=activation, + dropout=self.config.dropout, + ) + + load_scales_compat(self, scale_file) + + def calculate_quad_angles( + self, + V_st, + V_qint_st, + quad_idx, + ): + """Calculate angles for quadruplet-based message passing. + + Arguments + --------- + V_st: Tensor, shape = (nAtoms, 3) + Normalized directions from s to t + V_qint_st: Tensor, shape = (nAtoms, 3) + Normalized directions from s to t for the quadruplet + interaction graph + quad_idx: dict of torch.Tensor + Indices relevant for quadruplet interactions. + + Returns + ------- + cosφ_cab: Tensor, shape = (num_triplets_inint,) + Cosine of angle between atoms c -> a <- b. + cosφ_abd: Tensor, shape = (num_triplets_qint,) + Cosine of angle between atoms a -> b -> d. + angle_cabd: Tensor, shape = (num_quadruplets,) + Dihedral angle between atoms c <- a-b -> d. + """ + # ---------------------------------- d -> b -> a ---------------------------------- # + V_ba = V_qint_st[quad_idx["triplet_in"]["out"]] + # (num_triplets_qint, 3) + V_db = V_st[quad_idx["triplet_in"]["in"]] + # (num_triplets_qint, 3) + cosφ_abd = inner_product_clamped(V_ba, V_db) + # (num_triplets_qint,) + + # Project for calculating dihedral angle + # Cross product is the same as projection, just 90° rotated + V_db_cross = torch.cross(V_db, V_ba, dim=-1) # a - b -| d + V_db_cross = V_db_cross[quad_idx["trip_in_to_quad"]] + # (num_quadruplets,) + + # --------------------------------- c -> a <- b ---------------------------------- # + V_ca = V_st[quad_idx["triplet_out"]["out"]] # (num_triplets_in, 3) + V_ba = V_qint_st[quad_idx["triplet_out"]["in"]] # (num_triplets_in, 3) + cosφ_cab = inner_product_clamped(V_ca, V_ba) # (n4Triplets,) + + # Project for calculating dihedral angle + # Cross product is the same as projection, just 90° rotated + V_ca_cross = torch.cross(V_ca, V_ba, dim=-1) # c |- a - b + V_ca_cross = V_ca_cross[quad_idx["trip_out_to_quad"]] + # (num_quadruplets,) + + # -------------------------------- c -> a - b <- d -------------------------------- # + half_angle_cabd = get_angle(V_ca_cross, V_db_cross) + # (num_quadruplets,) + angle_cabd = half_angle_cabd + # Ignore parity and just use the half angle. + + return cosφ_cab, cosφ_abd, angle_cabd + + def select_symmetric_edges(self, tensor, mask, reorder_idx, opposite_neg): + """Use a mask to remove values of removed edges and then + duplicate the values for the correct edge direction. + + Arguments + --------- + tensor: torch.Tensor + Values to symmetrize for the new tensor. + mask: torch.Tensor + Mask defining which edges go in the correct direction. + reorder_idx: torch.Tensor + Indices defining how to reorder the tensor values after + concatenating the edge values of both directions. + opposite_neg: bool + Whether the edge in the opposite direction should use the + negative tensor value. + + Returns + ------- + tensor_ordered: torch.Tensor + A tensor with symmetrized values. + """ + # Mask out counter-edges + tensor_directed = tensor[mask] + # Concatenate counter-edges after normal edges + sign = 1 - 2 * opposite_neg + tensor_cat = torch.cat([tensor_directed, sign * tensor_directed]) + # Reorder everything so the edges of every image are consecutive + tensor_ordered = tensor_cat[reorder_idx] + return tensor_ordered + + def symmetrize_edges( + self, + graph, + batch_idx, + ): + """ + Symmetrize edges to ensure existence of counter-directional edges. + + Some edges are only present in one direction in the data, + since every atom has a maximum number of neighbors. + We only use i->j edges here. So we lose some j->i edges + and add others by making it symmetric. + """ + num_atoms = batch_idx.shape[0] + new_graph = {} + + # Generate mask + mask_sep_atoms = graph["edge_index"][0] < graph["edge_index"][1] + # Distinguish edges between the same (periodic) atom by ordering the cells + cell_earlier = ( + (graph["cell_offset"][:, 0] < 0) + | ((graph["cell_offset"][:, 0] == 0) & (graph["cell_offset"][:, 1] < 0)) + | ( + (graph["cell_offset"][:, 0] == 0) + & (graph["cell_offset"][:, 1] == 0) + & (graph["cell_offset"][:, 2] < 0) + ) + ) + mask_same_atoms = graph["edge_index"][0] == graph["edge_index"][1] + mask_same_atoms &= cell_earlier + mask = mask_sep_atoms | mask_same_atoms + + # Mask out counter-edges + edge_index_directed = graph["edge_index"][mask[None, :].expand(2, -1)].view( + 2, -1 + ) + + # Concatenate counter-edges after normal edges + edge_index_cat = torch.cat( + [edge_index_directed, edge_index_directed.flip(0)], + dim=1, + ) + + # Count remaining edges per image + batch_edge = torch.repeat_interleave( + torch.arange( + graph["num_neighbors"].size(0), + device=graph["edge_index"].device, + ), + graph["num_neighbors"], + ) + batch_edge = batch_edge[mask] + # segment_coo assumes sorted batch_edge + # Factor 2 since this is only one half of the edges + ones = batch_edge.new_ones(1).expand_as(batch_edge) + new_graph["num_neighbors"] = 2 * segment_coo( + ones, batch_edge, dim_size=graph["num_neighbors"].size(0) + ) + + # Create indexing array + edge_reorder_idx = repeat_blocks( + torch.div(new_graph["num_neighbors"], 2, rounding_mode="floor"), + repeats=2, + continuous_indexing=True, + repeat_inc=edge_index_directed.size(1), + ) + + # Reorder everything so the edges of every image are consecutive + new_graph["edge_index"] = edge_index_cat[:, edge_reorder_idx] + new_graph["cell_offset"] = self.select_symmetric_edges( + graph["cell_offset"], mask, edge_reorder_idx, True + ) + new_graph["distance"] = self.select_symmetric_edges( + graph["distance"], mask, edge_reorder_idx, False + ) + new_graph["vector"] = self.select_symmetric_edges( + graph["vector"], mask, edge_reorder_idx, True + ) + + # Indices for swapping c->a and a->c (for symmetric MP) + # To obtain these efficiently and without any index assumptions, + # we get order the counter-edge IDs and then + # map this order back to the edge IDs. + # Double argsort gives the desired mapping + # from the ordered tensor to the original tensor. + edge_ids = get_edge_id( + new_graph["edge_index"], new_graph["cell_offset"], num_atoms + ) + order_edge_ids = torch.argsort(edge_ids) + inv_order_edge_ids = torch.argsort(order_edge_ids) + edge_ids_counter = get_edge_id( + new_graph["edge_index"].flip(0), + -new_graph["cell_offset"], + num_atoms, + ) + order_edge_ids_counter = torch.argsort(edge_ids_counter) + id_swap = order_edge_ids_counter[inv_order_edge_ids] + + return new_graph, id_swap + + def get_graphs_and_indices(self, data): + """ "Generate embedding and interaction graphs and indices.""" + num_atoms = data.atomic_numbers.size(0) + assert ( + self.atom_edge_interaction + and self.edge_atom_interaction + and self.atom_interaction + and self.quad_interaction + ), "Only the full interaction graph (ae + ea + a + q) is supported." + + graphs = graphs_from_batch(data) + a2a_graph = graphs["a2a"] + a2ee2a_graph = graphs["a2ee2a"] + main_graph = graphs["main"] + qint_graph = graphs["qint"] + + # Symmetrize edges for swapping in symmetric message passing + if True: + main_graph, id_swap = self.symmetrize_edges(main_graph, data.batch) + else: + raise NotImplementedError + id_swap = main_graph.get("id_swap_edge_index", None) + if id_swap is None: + raise ValueError( + "Expected id_swap in main_graph for symmetric MP, but it was not found." + ) + + trip_idx_e2e = get_triplets(main_graph, num_atoms=num_atoms) + + # Additional indices for quadruplets + if self.quad_interaction: + quad_idx = get_quadruplets( + main_graph, + qint_graph, + num_atoms, + ) + else: + quad_idx = {} + + if self.atom_edge_interaction: + trip_idx_a2e = get_mixed_triplets( + a2ee2a_graph, + main_graph, + num_atoms=num_atoms, + return_agg_idx=True, + ) + else: + trip_idx_a2e = {} + if self.edge_atom_interaction: + trip_idx_e2a = get_mixed_triplets( + main_graph, + a2ee2a_graph, + num_atoms=num_atoms, + return_agg_idx=True, + ) + # a2ee2a_graph['edge_index'][1] has to be sorted for this + a2ee2a_graph["target_neighbor_idx"] = get_inner_idx( + a2ee2a_graph["edge_index"][1], dim_size=num_atoms + ) + else: + trip_idx_e2a = {} + if self.atom_interaction: + # a2a_graph['edge_index'][1] has to be sorted for this + a2a_graph["target_neighbor_idx"] = get_inner_idx( + a2a_graph["edge_index"][1], dim_size=num_atoms + ) + + return ( + main_graph, + a2a_graph, + a2ee2a_graph, + qint_graph, + id_swap, + trip_idx_e2e, + trip_idx_a2e, + trip_idx_e2a, + quad_idx, + ) + + def forward( + self, + data: BaseData, + *, + h: torch.Tensor, + ): + pos = data.pos + # batch = data.batch + # atomic_numbers = data.atomic_numbers.long() + num_atoms = data.atomic_numbers.shape[0] + + if self.regress_forces and not self.direct_forces: + pos.requires_grad_(True) + + ( + main_graph, + a2a_graph, + a2ee2a_graph, + qint_graph, + id_swap, + trip_idx_e2e, + trip_idx_a2e, + trip_idx_e2a, + quad_idx, + ) = self.get_graphs_and_indices(data) + idx_s, idx_t = main_graph["edge_index"] + + bases: BasesOutput = self.bases( + data, + h=h, + main_graph=main_graph, + a2a_graph=a2a_graph, + a2ee2a_graph=a2ee2a_graph, + qint_graph=qint_graph, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + num_atoms=num_atoms, + ) + m = bases.m + + # Embedding block + # h = self.atom_emb(atomic_numbers) + # (nAtoms, emb_size_atom) + # m = self.edge_emb(h, bases.rbf_main, main_graph["edge_index"]) + # (nEdges_main, emb_size_edge) + + x_E, x_F = self.out_blocks[0](h, m, bases.output, idx_t, data=data) + # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + xs_E, xs_F = [x_E], [x_F] + + for i in range(self.num_blocks): + if self.config.unique_basis_per_layer: + bases: BasesOutput = self.per_layer_bases[i]( + data, + h=h, + main_graph=main_graph, + a2a_graph=a2a_graph, + a2ee2a_graph=a2ee2a_graph, + qint_graph=qint_graph, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + num_atoms=num_atoms, + ) + m = m + bases.m + + # Interaction block + h, m = self.int_blocks[i]( + h=h, + m=m, + bases_qint=bases.qint, + bases_e2e=bases.e2e, + bases_a2e=bases.a2e, + bases_e2a=bases.e2a, + basis_a2a_rad=bases.a2a_rad, + basis_atom_update=bases.atom_update, + edge_index_main=main_graph["edge_index"], + a2ee2a_graph=a2ee2a_graph, + a2a_graph=a2a_graph, + id_swap=id_swap, + trip_idx_e2e=trip_idx_e2e, + trip_idx_a2e=trip_idx_a2e, + trip_idx_e2a=trip_idx_e2a, + quad_idx=quad_idx, + ) # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + + x_E, x_F = self.out_blocks[i + 1](h, m, bases.output, idx_t, data=data) + # (nAtoms, emb_size_atom), (nEdges, emb_size_edge) + xs_E.append(x_E) + xs_F.append(x_F) + + # Global output block for final predictions + if self.regress_forces: + assert self.direct_forces, "Only direct forces are supported for now." + x_F = self.out_mlp_F( + torch.cat(xs_F, dim=-1), data=data, edge_index=main_graph["edge_index"] + ) + else: + x_F = None + + if self.regress_energy: + x_E = self.out_mlp_E( + torch.cat(xs_E, dim=-1), data=data, edge_index=main_graph["edge_index"] + ) + else: + x_E = None + + out: GOCBackboneOutput = { + "energy": x_E, + "forces": x_F, + "V_st": main_graph["vector"], + "D_st": main_graph["distance"], + "idx_s": idx_s, + "idx_t": idx_t, + } + return out diff --git a/src/jmp/models/gemnet/bases.py b/src/jmp/models/gemnet/bases.py new file mode 100644 index 0000000..ade043e --- /dev/null +++ b/src/jmp/models/gemnet/bases.py @@ -0,0 +1,547 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from dataclasses import dataclass +from typing import TypedDict + +import torch +import torch.nn as nn +from torch_geometric.data import Batch +from torch_sparse import SparseTensor + +from .config import BasesConfig +from .layers.base_layers import Dense +from .layers.efficient import BasisEmbedding +from .layers.embedding_block import EdgeEmbedding +from .layers.radial_basis_dynamic_cutoff import GaussianBasis, RadialBasis +from .layers.spherical_basis_dynamic_cutoff import ( + CircularBasisLayer, + SphericalBasisLayer, +) +from .utils import get_angle, inner_product_clamped + +TripletIn = TypedDict( + "TripletIn", {"adj_edges": SparseTensor, "in": torch.Tensor, "out": torch.Tensor} +) +TripletOut = TypedDict("TripletOut", {"in": torch.Tensor, "out": torch.Tensor}) + + +class QuadIdx(TypedDict): + triplet_in: TripletIn + triplet_out: TripletOut + out: torch.Tensor + trip_out_to_quad: torch.Tensor + trip_in_to_quad: torch.Tensor + out_agg: torch.Tensor + + +class GraphBases(TypedDict): + rad: torch.Tensor + cir: list[torch.Tensor] + # cir: tuple[torch.Tensor, torch.Tensor] + + +class GraphBasesQInt(TypedDict): + rad: torch.Tensor + cir: torch.Tensor + sph: list[torch.Tensor] + # sph: tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class BasesOutput: + m: torch.Tensor # (e_main, emb_size_edge) + atom_update: torch.Tensor # (e_main, emb_size_rbf) + output: torch.Tensor # (e_main, emb_size_rbf) + qint: GraphBasesQInt + e2e: GraphBases + a2e: GraphBases + e2a: GraphBases + a2a_rad: torch.Tensor | None + + +class Bases(nn.Module): + def __init__( + self, + config: BasesConfig, + dropout: float | None = None, + ): + super().__init__() + + self.config = config + self.dropout = dropout + + self.init_basis_functions() + self.init_shared_basis_layers() + + self.edge_emb = EdgeEmbedding( + self.config.emb_size_atom, + self.config.num_radial, + self.config.emb_size_edge, + activation=self.config.activation, + dropout=self.dropout, + ) + + if not self.config.unique_per_layer: + self._set_shared_params() + + def _set_shared_params(self): + # Set shared parameters for better gradients + self.shared_parameters: list[tuple[nn.Parameter, float | int]] = [] + self.shared_parameters += [ + (self.mlp_rbf_tint.linear.weight, self.config.num_blocks), + (self.mlp_cbf_tint.weight, self.config.num_blocks), + (self.mlp_rbf_h.linear.weight, self.config.num_blocks), + (self.mlp_rbf_out.linear.weight, self.config.num_blocks + 1), + ] + if self.config.quad_interaction: + self.shared_parameters += [ + (self.mlp_rbf_qint.linear.weight, self.config.num_blocks), + (self.mlp_cbf_qint.weight, self.config.num_blocks), + (self.mlp_sbf_qint.weight, self.config.num_blocks), + ] + if self.config.atom_edge_interaction: + self.shared_parameters += [ + (self.mlp_rbf_aeint.linear.weight, self.config.num_blocks), + (self.mlp_cbf_aeint.weight, self.config.num_blocks), + ] + if self.config.edge_atom_interaction: + self.shared_parameters += [ + (self.mlp_rbf_eaint.linear.weight, self.config.num_blocks), + (self.mlp_cbf_eaint.weight, self.config.num_blocks), + ] + if self.config.atom_interaction: + self.shared_parameters += [ + (self.mlp_rbf_aint.weight, self.config.num_blocks), + ] + + self._add_rbf_shared_params(self.radial_basis) + if self.config.quad_interaction: + self._add_rbf_shared_params(self.cbf_basis_qint.radial_basis) + self._add_rbf_shared_params(self.sbf_basis_qint.radial_basis) + + if self.config.atom_edge_interaction: + self._add_rbf_shared_params(self.cbf_basis_aeint.radial_basis) + if self.config.edge_atom_interaction: + self._add_rbf_shared_params(self.cbf_basis_eaint.radial_basis) + if self.config.atom_interaction: + self._add_rbf_shared_params(self.radial_basis_aint) + self._add_rbf_shared_params(self.cbf_basis_tint.radial_basis) + + def _add_rbf_shared_params(self, radial: RadialBasis): + match radial.rbf: + case GaussianBasis(trainable=True) as gbf: + for param in gbf.parameters(): + self._add_shared_param(param, self.config.num_blocks) + case _: + pass + + def _add_shared_param(self, param: nn.Parameter, factor: int | float): + if ( + shared_param_idx := next( + (i for i, p in enumerate(self.shared_parameters) if p[0] is param), + None, + ) + ) is not None: + self.shared_parameters[shared_param_idx] = ( + param, + self.shared_parameters[shared_param_idx][1] + factor, + ) + else: + self.shared_parameters += [(param, factor)] + + def init_basis_functions(self): + self.radial_basis = RadialBasis( + num_radial=self.config.num_radial, + graph_type="main", + rbf=self.config.rbf, + envelope=self.config.envelope, + scale_basis=self.config.scale_basis, + absolute_cutoff=self.config.absolute_rbf_cutoff, + ) + radial_basis_spherical = RadialBasis( + num_radial=self.config.num_radial, + graph_type="main", + rbf=self.config.rbf_spherical, + envelope=self.config.envelope, + scale_basis=self.config.scale_basis, + absolute_cutoff=self.config.absolute_rbf_cutoff, + ) + if self.config.quad_interaction: + radial_basis_spherical_qint = RadialBasis( + num_radial=self.config.num_radial, + graph_type="qint", + rbf=self.config.rbf_spherical, + envelope=self.config.envelope, + scale_basis=self.config.scale_basis, + absolute_cutoff=self.config.absolute_rbf_cutoff, + ) + self.cbf_basis_qint = CircularBasisLayer( + self.config.num_spherical, + radial_basis=radial_basis_spherical_qint, + cbf=self.config.cbf, + scale_basis=self.config.scale_basis, + ) + + self.sbf_basis_qint = SphericalBasisLayer( + self.config.num_spherical, + radial_basis=radial_basis_spherical, + sbf=self.config.sbf, + scale_basis=self.config.scale_basis, + ) + if self.config.atom_edge_interaction: + self.radial_basis_aeaint = RadialBasis( + num_radial=self.config.num_radial, + graph_type="a2ee2a", + rbf=self.config.rbf, + envelope=self.config.envelope, + scale_basis=self.config.scale_basis, + absolute_cutoff=self.config.absolute_rbf_cutoff, + ) + self.cbf_basis_aeint = CircularBasisLayer( + self.config.num_spherical, + radial_basis=radial_basis_spherical, + cbf=self.config.cbf, + scale_basis=self.config.scale_basis, + ) + if self.config.edge_atom_interaction: + self.radial_basis_aeaint = RadialBasis( + num_radial=self.config.num_radial, + graph_type="a2ee2a", + rbf=self.config.rbf, + envelope=self.config.envelope, + scale_basis=self.config.scale_basis, + absolute_cutoff=self.config.absolute_rbf_cutoff, + ) + radial_basis_spherical_aeaint = RadialBasis( + num_radial=self.config.num_radial, + graph_type="a2ee2a", + rbf=self.config.rbf_spherical, + envelope=self.config.envelope, + scale_basis=self.config.scale_basis, + absolute_cutoff=self.config.absolute_rbf_cutoff, + ) + self.cbf_basis_eaint = CircularBasisLayer( + self.config.num_spherical, + radial_basis=radial_basis_spherical_aeaint, + cbf=self.config.cbf, + scale_basis=self.config.scale_basis, + ) + if self.config.atom_interaction: + self.radial_basis_aint = RadialBasis( + num_radial=self.config.num_radial, + graph_type="a2a", + rbf=self.config.rbf, + envelope=self.config.envelope, + scale_basis=self.config.scale_basis, + absolute_cutoff=self.config.absolute_rbf_cutoff, + ) + + self.cbf_basis_tint = CircularBasisLayer( + self.config.num_spherical, + radial_basis=radial_basis_spherical, + cbf=self.config.cbf, + scale_basis=self.config.scale_basis, + ) + + def init_shared_basis_layers(self): + # Share basis down projections across all interaction blocks + if self.config.quad_interaction: + self.mlp_rbf_qint = Dense( + self.config.num_radial, + self.config.emb_size_rbf, + activation=None, + bias=False, + dropout=self.dropout, + ) + self.mlp_cbf_qint = BasisEmbedding( + self.config.num_radial, + self.config.emb_size_cbf, + self.config.num_spherical, + ) + self.mlp_sbf_qint = BasisEmbedding( + self.config.num_radial, + self.config.emb_size_sbf, + self.config.num_spherical**2, + ) + + if self.config.atom_edge_interaction: + self.mlp_rbf_aeint = Dense( + self.config.num_radial, + self.config.emb_size_rbf, + activation=None, + bias=False, + dropout=self.dropout, + ) + self.mlp_cbf_aeint = BasisEmbedding( + self.config.num_radial, + self.config.emb_size_cbf, + self.config.num_spherical, + ) + if self.config.edge_atom_interaction: + self.mlp_rbf_eaint = Dense( + self.config.num_radial, + self.config.emb_size_rbf, + activation=None, + bias=False, + dropout=self.dropout, + ) + self.mlp_cbf_eaint = BasisEmbedding( + self.config.num_radial, + self.config.emb_size_cbf, + self.config.num_spherical, + ) + if self.config.atom_interaction: + self.mlp_rbf_aint = BasisEmbedding( + self.config.num_radial, self.config.emb_size_rbf + ) + + self.mlp_rbf_tint = Dense( + self.config.num_radial, + self.config.emb_size_rbf, + activation=None, + bias=False, + dropout=self.dropout, + ) + self.mlp_cbf_tint = BasisEmbedding( + self.config.num_radial, self.config.emb_size_cbf, self.config.num_spherical + ) + + # Share the dense Layer of the atom embedding block accross the interaction blocks + self.mlp_rbf_h = Dense( + self.config.num_radial, + self.config.emb_size_rbf, + activation=None, + bias=False, + dropout=self.dropout, + ) + self.mlp_rbf_out = Dense( + self.config.num_radial, + self.config.emb_size_rbf, + activation=None, + bias=False, + dropout=self.dropout, + ) + + def calculate_quad_angles( + self, + V_st: torch.Tensor, + V_qint_st: torch.Tensor, + quad_idx: dict, + ): + """Calculate angles for quadruplet-based message passing. + + Arguments + --------- + V_st: Tensor, shape = (nAtoms, 3) + Normalized directions from s to t + V_qint_st: Tensor, shape = (nAtoms, 3) + Normalized directions from s to t for the quadruplet + interaction graph + quad_idx: dict of torch.Tensor + Indices relevant for quadruplet interactions. + + Returns + ------- + cosφ_cab: Tensor, shape = (num_triplets_inint,) + Cosine of angle between atoms c -> a <- b. + cosφ_abd: Tensor, shape = (num_triplets_qint,) + Cosine of angle between atoms a -> b -> d. + angle_cabd: Tensor, shape = (num_quadruplets,) + Dihedral angle between atoms c <- a-b -> d. + """ + # ---------------------------------- d -> b -> a ---------------------------------- # + V_ba = V_qint_st[quad_idx["triplet_in"]["out"]] + # (num_triplets_qint, 3) + V_db = V_st[quad_idx["triplet_in"]["in"]] + # (num_triplets_qint, 3) + cosφ_abd = inner_product_clamped(V_ba, V_db) + # (num_triplets_qint,) + + # Project for calculating dihedral angle + # Cross product is the same as projection, just 90° rotated + V_db_cross = torch.cross(V_db, V_ba, dim=-1) # a - b -| d + V_db_cross = V_db_cross[quad_idx["trip_in_to_quad"]] + # (num_quadruplets,) + + # --------------------------------- c -> a <- b ---------------------------------- # + V_ca = V_st[quad_idx["triplet_out"]["out"]] # (num_triplets_in, 3) + V_ba = V_qint_st[quad_idx["triplet_out"]["in"]] # (num_triplets_in, 3) + cosφ_cab = inner_product_clamped(V_ca, V_ba) # (n4Triplets,) + + # Project for calculating dihedral angle + # Cross product is the same as projection, just 90° rotated + V_ca_cross = torch.cross(V_ca, V_ba, dim=-1) # c |- a - b + V_ca_cross = V_ca_cross[quad_idx["trip_out_to_quad"]] + # (num_quadruplets,) + + # -------------------------------- c -> a - b <- d -------------------------------- # + half_angle_cabd = get_angle(V_ca_cross, V_db_cross) + # (num_quadruplets,) + angle_cabd = half_angle_cabd + # Ignore parity and just use the half angle. + + return cosφ_cab, cosφ_abd, angle_cabd + + def forward( + self, + data: Batch, + *, + h: torch.Tensor, + main_graph: dict, + a2a_graph: dict, + a2ee2a_graph: dict, + qint_graph: dict, + trip_idx_e2e: dict, + trip_idx_a2e: dict, + trip_idx_e2a: dict, + quad_idx: dict, + num_atoms: int, + ): + """Calculate and transform basis functions.""" + basis_rad_main_raw = self.radial_basis(main_graph["distance"], data=data) + + # Calculate triplet angles + cosφ_cab = inner_product_clamped( + main_graph["vector"][trip_idx_e2e["out"]], + main_graph["vector"][trip_idx_e2e["in"]], + ) + basis_rad_cir_e2e_raw, basis_cir_e2e_raw = self.cbf_basis_tint( + main_graph["distance"], + cosφ_cab, + data=data, + ) + + if self.config.quad_interaction: + # Calculate quadruplet angles + cosφ_cab_q, cosφ_abd, angle_cabd = self.calculate_quad_angles( + main_graph["vector"], + qint_graph["vector"], + quad_idx, + ) + + basis_rad_cir_qint_raw, basis_cir_qint_raw = self.cbf_basis_qint( + qint_graph["distance"], + cosφ_abd, + data=data, + ) + basis_rad_sph_qint_raw, basis_sph_qint_raw = self.sbf_basis_qint( + main_graph["distance"], + cosφ_cab_q[quad_idx["trip_out_to_quad"]], + angle_cabd, + data=data, + ) + if self.config.atom_edge_interaction: + basis_rad_a2ee2a_raw = self.radial_basis_aeaint( + a2ee2a_graph["distance"], + data=data, + ) + cosφ_cab_a2e = inner_product_clamped( + main_graph["vector"][trip_idx_a2e["out"]], + a2ee2a_graph["vector"][trip_idx_a2e["in"]], + ) + basis_rad_cir_a2e_raw, basis_cir_a2e_raw = self.cbf_basis_aeint( + main_graph["distance"], + cosφ_cab_a2e, + data=data, + ) + if self.config.edge_atom_interaction: + cosφ_cab_e2a = inner_product_clamped( + a2ee2a_graph["vector"][trip_idx_e2a["out"]], + main_graph["vector"][trip_idx_e2a["in"]], + ) + basis_rad_cir_e2a_raw, basis_cir_e2a_raw = self.cbf_basis_eaint( + a2ee2a_graph["distance"], + cosφ_cab_e2a, + data=data, + ) + if self.config.atom_interaction: + basis_rad_a2a_raw = self.radial_basis_aint( + a2a_graph["distance"], + data=data, + ) + + # Shared Down Projections + bases_qint: GraphBasesQInt = {} + if self.config.quad_interaction: + bases_qint["rad"] = self.mlp_rbf_qint(basis_rad_main_raw) + bases_qint["cir"] = self.mlp_cbf_qint( + rad_basis=basis_rad_cir_qint_raw, + sph_basis=basis_cir_qint_raw, + idx_sph_outer=quad_idx["triplet_in"]["out"], + ) + bases_qint["sph"] = list( + self.mlp_sbf_qint( + rad_basis=basis_rad_sph_qint_raw, + sph_basis=basis_sph_qint_raw, + idx_sph_outer=quad_idx["out"], + idx_sph_inner=quad_idx["out_agg"], + ) + ) + + bases_a2e: GraphBases = {} + if self.config.atom_edge_interaction: + bases_a2e["rad"] = self.mlp_rbf_aeint(basis_rad_a2ee2a_raw) + bases_a2e["cir"] = list( + self.mlp_cbf_aeint( + rad_basis=basis_rad_cir_a2e_raw, + sph_basis=basis_cir_a2e_raw, + idx_sph_outer=trip_idx_a2e["out"], + idx_sph_inner=trip_idx_a2e["out_agg"], + ) + ) + bases_e2a: GraphBases = {} + if self.config.edge_atom_interaction: + bases_e2a["rad"] = self.mlp_rbf_eaint(basis_rad_main_raw) + bases_e2a["cir"] = list( + self.mlp_cbf_eaint( + rad_basis=basis_rad_cir_e2a_raw, + sph_basis=basis_cir_e2a_raw, + idx_rad_outer=a2ee2a_graph["edge_index"][1], + idx_rad_inner=a2ee2a_graph["target_neighbor_idx"], + idx_sph_outer=trip_idx_e2a["out"], + idx_sph_inner=trip_idx_e2a["out_agg"], + num_atoms=num_atoms, + ) + ) + if self.config.atom_interaction: + basis_a2a_rad = self.mlp_rbf_aint( + rad_basis=basis_rad_a2a_raw, + idx_rad_outer=a2a_graph["edge_index"][1], + idx_rad_inner=a2a_graph["target_neighbor_idx"], + num_atoms=num_atoms, + ) + else: + basis_a2a_rad = None + + bases_e2e: GraphBases = {} + bases_e2e["rad"] = self.mlp_rbf_tint(basis_rad_main_raw) + bases_e2e["cir"] = list( + self.mlp_cbf_tint( + rad_basis=basis_rad_cir_e2e_raw, + sph_basis=basis_cir_e2e_raw, + idx_sph_outer=trip_idx_e2e["out"], + idx_sph_inner=trip_idx_e2e["out_agg"], + ) + ) + + basis_atom_update = self.mlp_rbf_h(basis_rad_main_raw) + basis_output = self.mlp_rbf_out(basis_rad_main_raw) + + m = self.edge_emb(h, basis_rad_main_raw, main_graph["edge_index"]) + + return BasesOutput( + m, # (e_main, emb_size_edge) + basis_atom_update, # (e_main, emb_size_rbf) + basis_output, # (e_main, emb_size_rbf) + bases_qint, # rad=(e_main, emb_size_rbf), cir=(num_triplets_qint, emb_size_cbf), sph=[(e_main, emb_size_sbf, num_spherical**2), (e_main, num_spherical**2, n_atoms)] + bases_e2e, # rad=(e_main, emb_size_rbf), cir=[(e_main, emb_size_sbf, num_spherical), (e_main, num_spherical, Kmax_e2e)] + bases_a2e, # rad=(e_a2ee2a, emb_size_rbf), cir=[(e_main, emb_size_sbf, num_spherical), (e_main, num_spherical, Kmax_a2e)] + bases_e2a, # rad=(e_main, emb_size_rbf), cir=[(n_atoms, emb_size_sbf, emb_size_interim), (e_a2ee2a, num_spherical, Kmax_e2e)] + basis_a2a_rad, # (n_atoms, emb_size_rbf, emb_size_interm) + ) diff --git a/src/jmp/models/gemnet/config.py b/src/jmp/models/gemnet/config.py new file mode 100644 index 0000000..a08713c --- /dev/null +++ b/src/jmp/models/gemnet/config.py @@ -0,0 +1,311 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import torch +from jmp.lightning import TypedConfig +from typing_extensions import override + + +class BackboneConfig(TypedConfig): + num_targets: int = 1 + num_spherical: int + num_radial: int + num_blocks: int + emb_size_atom: int + emb_size_edge: int + emb_size_trip_in: int + emb_size_trip_out: int + emb_size_quad_in: int + emb_size_quad_out: int + emb_size_aint_in: int + emb_size_aint_out: int + emb_size_rbf: int + emb_size_cbf: int + emb_size_sbf: int + num_before_skip: int + num_after_skip: int + num_concat: int + num_atom: int + num_output_afteratom: int + num_atom_emb_layers: int = 0 + num_global_out_layers: int = 2 + regress_forces: bool = True + regress_energy: bool = True + direct_forces: bool = False + use_pbc: bool = True + scale_backprop_forces: bool = False + rbf: dict = {"name": "gaussian"} + rbf_spherical: dict | None = None + envelope: dict = {"name": "polynomial", "exponent": 5} + cbf: dict = {"name": "spherical_harmonics"} + sbf: dict = {"name": "spherical_harmonics"} + extensive: bool = True + forces_coupled: bool = False + activation: str = "scaled_silu" + quad_interaction: bool = False + atom_edge_interaction: bool = False + edge_atom_interaction: bool = False + atom_interaction: bool = False + scale_basis: bool = False + qint_tags: list = [0, 1, 2] + num_elements: int = 120 + otf_graph: bool = False + scale_file: str | None = None + + absolute_rbf_cutoff: float | None = None + learnable_rbf: bool = False + learnable_rbf_stds: bool = False + + unique_basis_per_layer: bool = False + + dropout: float | None + edge_dropout: float | None + + @classmethod + def base(cls): + return cls( + num_targets=1, + num_spherical=7, + num_radial=128, + num_blocks=4, + emb_size_atom=256, + emb_size_edge=512, + emb_size_trip_in=64, + emb_size_trip_out=64, + emb_size_quad_in=32, + emb_size_quad_out=32, + emb_size_aint_in=64, + emb_size_aint_out=64, + emb_size_rbf=16, + emb_size_cbf=16, + emb_size_sbf=32, + num_before_skip=2, + num_after_skip=2, + num_concat=1, + num_atom=3, + num_output_afteratom=3, + num_atom_emb_layers=2, + num_global_out_layers=2, + regress_forces=True, + regress_energy=True, + direct_forces=True, + use_pbc=True, + scale_backprop_forces=False, + rbf={"name": "gaussian"}, + rbf_spherical=None, + envelope={"name": "polynomial", "exponent": 5}, + cbf={"name": "spherical_harmonics"}, + sbf={"name": "legendre_outer"}, + extensive=True, + forces_coupled=False, + activation="scaled_silu", + quad_interaction=True, + atom_edge_interaction=True, + edge_atom_interaction=True, + atom_interaction=True, + scale_basis=False, + qint_tags=[1, 2], + num_elements=120, + otf_graph=False, + scale_file=None, + absolute_rbf_cutoff=12.0, + dropout=None, + edge_dropout=None, + ) + + @classmethod + def download_base(cls, **kwargs): + """Load GemNetOCBackbone config from github""" + import requests + import yaml + + # read config from url + response = requests.get( + "https://raw.githubusercontent.com/Open-Catalyst-Project/ocp/main/configs/s2ef/all/gemnet/gemnet-oc.yml" + ) + config_dict = yaml.safe_load(response.text) + + model_config: dict = {**config_dict["model"]} + _ = model_config.pop("name", None) + _ = model_config.pop("scale_file", None) + + for key in list(model_config.keys()): + if any([key.startswith(prefix) for prefix in ["cutoff", "max_neighbors"]]): + _ = model_config.pop(key) + + model_config.update(kwargs) + config = cls.from_dict(model_config) + return config + + @classmethod + def large(cls): + return cls( + **{ + "num_targets": 1, + "num_spherical": 7, + "num_radial": 128, + "num_blocks": 6, + "emb_size_atom": 256, + "emb_size_edge": 1024, + "emb_size_trip_in": 64, + "emb_size_trip_out": 128, + "emb_size_quad_in": 64, + "emb_size_quad_out": 32, + "emb_size_aint_in": 64, + "emb_size_aint_out": 64, + "emb_size_rbf": 32, + "emb_size_cbf": 16, + "emb_size_sbf": 64, + "num_before_skip": 2, + "num_after_skip": 2, + "num_concat": 4, + "num_atom": 3, + "num_output_afteratom": 3, + "num_atom_emb_layers": 2, + "num_global_out_layers": 2, + "regress_forces": True, + "regress_energy": True, + "direct_forces": True, + "use_pbc": True, + "scale_backprop_forces": False, + "rbf": {"name": "gaussian"}, + "rbf_spherical": None, + "envelope": {"name": "polynomial", "exponent": 5}, + "cbf": {"name": "spherical_harmonics"}, + "sbf": {"name": "legendre_outer"}, + "extensive": True, + "forces_coupled": False, + "activation": "scaled_silu", + "quad_interaction": True, + "atom_edge_interaction": True, + "edge_atom_interaction": True, + "atom_interaction": True, + "scale_basis": False, + "qint_tags": [1, 2], + "num_elements": 120, + "otf_graph": False, + "scale_file": None, + "learnable_rbf": False, + "learnable_rbf_stds": False, + "unique_basis_per_layer": False, + }, + absolute_rbf_cutoff=12.0, + dropout=None, + edge_dropout=None, + ) + + @classmethod + def download_large(cls, **kwargs): + """Load GemNetOCBackbone config from github""" + import requests + import yaml + + # read config from url + response = requests.get( + "https://raw.githubusercontent.com/Open-Catalyst-Project/ocp/main/configs/s2ef/all/gemnet/gemnet-oc-large.yml" + ) + config_dict = yaml.safe_load(response.text) + + model_config: dict = {**config_dict["model"]} + _ = model_config.pop("name", None) + _ = model_config.pop("scale_file", None) + + for key in list(model_config.keys()): + if any([key.startswith(prefix) for prefix in ["cutoff", "max_neighbors"]]): + _ = model_config.pop(key) + + model_config.update(kwargs) + config = cls.from_dict(model_config) + return config + + @classmethod + def from_ckpt( + cls, + ckpt_path: Path | str, + transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + ): + ckpt = torch.load(ckpt_path, map_location="cpu") + config = ckpt["config"] + if transform is not None: + config = transform(config) + return cls(**config) + + +class BasesConfig(TypedConfig): + emb_size_rbf: int + emb_size_cbf: int + emb_size_sbf: int + num_spherical: int + num_radial: int + rbf: dict = {"name": "gaussian"} + rbf_spherical: dict | None = None + envelope: dict = {"name": "polynomial", "exponent": 5} + cbf: dict = {"name": "spherical_harmonics"} + sbf: dict = {"name": "spherical_harmonics"} + scale_basis: bool = False + absolute_rbf_cutoff: float | None = None + + num_blocks: int + quad_interaction: bool + atom_edge_interaction: bool + edge_atom_interaction: bool + atom_interaction: bool + + emb_size_atom: int + emb_size_edge: int + activation: str + + learnable: bool = False + learnable_rbf_stds: bool = False + + unique_per_layer: bool = False + + @classmethod + def from_backbone_config(cls, backbone_config: BackboneConfig): + return cls( + emb_size_rbf=backbone_config.emb_size_rbf, + emb_size_cbf=backbone_config.emb_size_cbf, + emb_size_sbf=backbone_config.emb_size_sbf, + num_spherical=backbone_config.num_spherical, + num_radial=backbone_config.num_radial, + rbf=backbone_config.rbf, + rbf_spherical=backbone_config.rbf_spherical, + envelope=backbone_config.envelope, + cbf=backbone_config.cbf, + sbf=backbone_config.sbf, + scale_basis=backbone_config.scale_basis, + absolute_rbf_cutoff=backbone_config.absolute_rbf_cutoff, + num_blocks=backbone_config.num_blocks, + quad_interaction=backbone_config.quad_interaction, + atom_edge_interaction=backbone_config.atom_edge_interaction, + edge_atom_interaction=backbone_config.edge_atom_interaction, + atom_interaction=backbone_config.atom_interaction, + emb_size_atom=backbone_config.emb_size_atom, + emb_size_edge=backbone_config.emb_size_edge, + activation=backbone_config.activation, + learnable=backbone_config.learnable_rbf, + learnable_rbf_stds=backbone_config.learnable_rbf_stds, + unique_per_layer=backbone_config.unique_basis_per_layer, + ) + + @override + def __post_init__(self): + if not self.rbf_spherical: + self.rbf_spherical = self.rbf.copy() + + if self.learnable: + self.rbf["trainable"] = True + self.rbf_spherical["trainable"] = True + + if self.learnable_rbf_stds: + self.rbf["trainable_stds"] = True + self.rbf_spherical["trainable_stds"] = True diff --git a/src/jmp/models/gemnet/initializers.py b/src/jmp/models/gemnet/initializers.py new file mode 100644 index 0000000..bbaf94b --- /dev/null +++ b/src/jmp/models/gemnet/initializers.py @@ -0,0 +1,97 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import torch + + +def _standardize(kernel): + """ + Makes sure that N*Var(W) = 1 and E[W] = 0 + """ + eps = 1e-6 + + if len(kernel.shape) == 3: + axis = [0, 1] # last dimension is output dimension + else: + axis = 1 + + var, mean = torch.var_mean(kernel, dim=axis, unbiased=True, keepdim=True) + kernel = (kernel - mean) / (var + eps) ** 0.5 + return kernel + + +def he_orthogonal_init(tensor): + """ + Generate a weight matrix with variance according to He (Kaiming) initialization. + Based on a random (semi-)orthogonal matrix neural networks + are expected to learn better when features are decorrelated + (stated by eg. "Reducing overfitting in deep networks by decorrelating representations", + "Dropout: a simple way to prevent neural networks from overfitting", + "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks") + """ + tensor = torch.nn.init.orthogonal_(tensor) + + if len(tensor.shape) == 3: + fan_in = tensor.shape[:-1].numel() + else: + fan_in = tensor.shape[1] + + with torch.no_grad(): + tensor.data = _standardize(tensor.data) + tensor.data *= (1 / fan_in) ** 0.5 + + return tensor + + +def grid_init(tensor, start=-1, end=1): + """ + Generate a weight matrix so that each input value corresponds to one value on a regular grid between start and end. + """ + fan_in = tensor.shape[1] + + with torch.no_grad(): + data = torch.linspace( + start, end, fan_in, device=tensor.device, dtype=tensor.dtype + ).expand_as(tensor) + tensor.copy_(data) + + return tensor + + +def log_grid_init(tensor, start=-4, end=0): + """ + Generate a weight matrix so that each input value corresponds to one value on a regular logarithmic grid between 10^start and 10^end. + """ + fan_in = tensor.shape[1] + + with torch.no_grad(): + data = torch.logspace( + start, end, fan_in, device=tensor.device, dtype=tensor.dtype + ).expand_as(tensor) + tensor.copy_(data) + + return tensor + + +def get_initializer(name, **init_kwargs): + name = name.lower() + if name == "heorthogonal": + initializer = he_orthogonal_init + elif name == "zeros": + initializer = torch.nn.init.zeros_ + elif name == "grid": + initializer = grid_init + elif name == "loggrid": + initializer = log_grid_init + else: + raise UserWarning(f"Unknown initializer: {name}") + + initializer = partial(initializer, **init_kwargs) + return initializer diff --git a/src/jmp/models/gemnet/interaction_indices.py b/src/jmp/models/gemnet/interaction_indices.py new file mode 100644 index 0000000..c2d90d4 --- /dev/null +++ b/src/jmp/models/gemnet/interaction_indices.py @@ -0,0 +1,302 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch_scatter import segment_coo +from torch_sparse import SparseTensor + +from .utils import get_inner_idx, masked_select_sparsetensor_flat + + +def get_triplets(graph, num_atoms): + """ + Get all input edges b->a for each output edge c->a. + It is possible that b=c, as long as the edges are distinct + (i.e. atoms b and c stem from different unit cells). + + Arguments + --------- + graph: dict of torch.Tensor + Contains the graph's edge_index. + num_atoms: int + Total number of atoms. + + Returns + ------- + Dictionary containing the entries: + in: torch.Tensor, shape (num_triplets,) + Indices of input edge b->a of each triplet b->a<-c + out: torch.Tensor, shape (num_triplets,) + Indices of output edge c->a of each triplet b->a<-c + out_agg: torch.Tensor, shape (num_triplets,) + Indices enumerating the intermediate edges of each output edge. + Used for creating a padded matrix and aggregating via matmul. + """ + idx_s, idx_t = graph["edge_index"] # c->a (source=c, target=a) + num_edges = idx_s.size(0) + + value = torch.arange(num_edges, device=idx_s.device, dtype=idx_s.dtype) + # Possibly contains multiple copies of the same edge (for periodic interactions) + adj = SparseTensor( + row=idx_t, + col=idx_s, + value=value, + sparse_sizes=(num_atoms, num_atoms), + ) + adj_edges = adj[idx_t] + + # Edge indices (b->a, c->a) for triplets. + idx = {} + idx["in"] = adj_edges.storage.value() + idx["out"] = adj_edges.storage.row() + + # Remove self-loop triplets + # Compare edge indices, not atom indices to correctly handle periodic interactions + mask = idx["in"] != idx["out"] + idx["in"] = idx["in"][mask] + idx["out"] = idx["out"][mask] + + # idx['out'] has to be sorted for this + idx["out_agg"] = get_inner_idx(idx["out"], dim_size=num_edges) + + return idx + + +def get_mixed_triplets( + graph_in, + graph_out, + num_atoms, + to_outedge=False, + return_adj=False, + return_agg_idx=False, +): + """ + Get all output edges (ingoing or outgoing) for each incoming edge. + It is possible that in atom=out atom, as long as the edges are distinct + (i.e. they stem from different unit cells). In edges and out edges stem + from separate graphs (hence "mixed") with shared atoms. + + Arguments + --------- + graph_in: dict of torch.Tensor + Contains the input graph's edge_index and cell_offset. + graph_out: dict of torch.Tensor + Contains the output graph's edge_index and cell_offset. + Input and output graphs use the same atoms, but different edges. + num_atoms: int + Total number of atoms. + to_outedge: bool + Whether to map the output to the atom's outgoing edges a->c + instead of the ingoing edges c->a. + return_adj: bool + Whether to output the adjacency (incidence) matrix between output + edges and atoms adj_edges. + return_agg_idx: bool + Whether to output the indices enumerating the intermediate edges + of each output edge. + + Returns + ------- + Dictionary containing the entries: + in: torch.Tensor, shape (num_triplets,) + Indices of input edges + out: torch.Tensor, shape (num_triplets,) + Indices of output edges + adj_edges: SparseTensor, shape (num_edges, num_atoms) + Adjacency (incidence) matrix between output edges and atoms, + with values specifying the input edges. + Only returned if return_adj is True. + out_agg: torch.Tensor, shape (num_triplets,) + Indices enumerating the intermediate edges of each output edge. + Used for creating a padded matrix and aggregating via matmul. + Only returned if return_agg_idx is True. + """ + idx_out_s, idx_out_t = graph_out["edge_index"] + # c->a (source=c, target=a) + idx_in_s, idx_in_t = graph_in["edge_index"] + num_edges = idx_out_s.size(0) + + value_in = torch.arange( + idx_in_s.size(0), device=idx_in_s.device, dtype=idx_in_s.dtype + ) + # This exploits that SparseTensor can have multiple copies of the same edge! + adj_in = SparseTensor( + row=idx_in_t, + col=idx_in_s, + value=value_in, + sparse_sizes=(num_atoms, num_atoms), + ) + if to_outedge: + adj_edges = adj_in[idx_out_s] + else: + adj_edges = adj_in[idx_out_t] + + # Edge indices (b->a, c->a) for triplets. + idx_in = adj_edges.storage.value() + idx_out = adj_edges.storage.row() + + # Remove self-loop triplets c->a<-c or c<-a<-c + # Check atom as well as cell offset + if to_outedge: + idx_atom_in = idx_in_s[idx_in] + idx_atom_out = idx_out_t[idx_out] + cell_offsets_sum = ( + graph_out["cell_offset"][idx_out] + graph_in["cell_offset"][idx_in] + ) + else: + idx_atom_in = idx_in_s[idx_in] + idx_atom_out = idx_out_s[idx_out] + cell_offsets_sum = ( + graph_out["cell_offset"][idx_out] - graph_in["cell_offset"][idx_in] + ) + mask = (idx_atom_in != idx_atom_out) | torch.any(cell_offsets_sum != 0, dim=-1) + + idx = {} + if return_adj: + idx["adj_edges"] = masked_select_sparsetensor_flat(adj_edges, mask) + idx["in"] = idx["adj_edges"].storage.value().clone() + idx["out"] = idx["adj_edges"].storage.row() + else: + idx["in"] = idx_in[mask] + idx["out"] = idx_out[mask] + + if return_agg_idx: + # idx['out'] has to be sorted + idx["out_agg"] = get_inner_idx(idx["out"], dim_size=num_edges) + + return idx + + +def get_quadruplets( + main_graph, + qint_graph, + num_atoms, +): + """ + Get all d->b for each edge c->a and connection b->a + Careful about periodic images! + Separate interaction cutoff not supported. + + Arguments + --------- + main_graph: dict of torch.Tensor + Contains the main graph's edge_index and cell_offset. + The main graph defines which edges are embedded. + qint_graph: dict of torch.Tensor + Contains the quadruplet interaction graph's edge_index and + cell_offset. main_graph and qint_graph use the same atoms, + but different edges. + num_atoms: int + Total number of atoms. + + Returns + ------- + Dictionary containing the entries: + triplet_in['in']: torch.Tensor, shape (nTriplets,) + Indices of input edge d->b in triplet d->b->a. + triplet_in['out']: torch.Tensor, shape (nTriplets,) + Interaction indices of output edge b->a in triplet d->b->a. + triplet_out['in']: torch.Tensor, shape (nTriplets,) + Interaction indices of input edge b->a in triplet c->a<-b. + triplet_out['out']: torch.Tensor, shape (nTriplets,) + Indices of output edge c->a in triplet c->a<-b. + out: torch.Tensor, shape (nQuadruplets,) + Indices of output edge c->a in quadruplet + trip_in_to_quad: torch.Tensor, shape (nQuadruplets,) + Indices to map from input triplet d->b->a + to quadruplet d->b->a<-c. + trip_out_to_quad: torch.Tensor, shape (nQuadruplets,) + Indices to map from output triplet c->a<-b + to quadruplet d->b->a<-c. + out_agg: torch.Tensor, shape (num_triplets,) + Indices enumerating the intermediate edges of each output edge. + Used for creating a padded matrix and aggregating via matmul. + """ + idx_s, _ = main_graph["edge_index"] + idx_qint_s, _ = qint_graph["edge_index"] + # c->a (source=c, target=a) + num_edges = idx_s.size(0) + idx = {} + + idx["triplet_in"] = get_mixed_triplets( + main_graph, + qint_graph, + num_atoms, + to_outedge=True, + return_adj=True, + ) + # Input triplets d->b->a + + idx["triplet_out"] = get_mixed_triplets( + qint_graph, + main_graph, + num_atoms, + to_outedge=False, + ) + # Output triplets c->a<-b + + # ---------------- Quadruplets ----------------- + # Repeat indices by counting the number of input triplets per + # intermediate edge ba. segment_coo assumes sorted idx['triplet_in']['out'] + ones = idx["triplet_in"]["out"].new_ones(1).expand_as(idx["triplet_in"]["out"]) + num_trip_in_per_inter = segment_coo( + ones, idx["triplet_in"]["out"], dim_size=idx_qint_s.size(0) + ) + + num_trip_out_per_inter = num_trip_in_per_inter[idx["triplet_out"]["in"]] + idx["out"] = torch.repeat_interleave( + idx["triplet_out"]["out"], num_trip_out_per_inter + ) + idx_inter = torch.repeat_interleave( + idx["triplet_out"]["in"], num_trip_out_per_inter + ) + idx["trip_out_to_quad"] = torch.repeat_interleave( + torch.arange( + len(idx["triplet_out"]["out"]), + device=idx_s.device, + dtype=idx_s.dtype, + ), + num_trip_out_per_inter, + ) + + # Generate input indices by using the adjacency + # matrix idx['triplet_in']['adj_edges'] + idx["triplet_in"]["adj_edges"].set_value_( + torch.arange( + len(idx["triplet_in"]["in"]), + device=idx_s.device, + dtype=idx_s.dtype, + ), + layout="coo", + ) + adj_trip_in_per_trip_out = idx["triplet_in"]["adj_edges"][idx["triplet_out"]["in"]] + # Rows in adj_trip_in_per_trip_out are intermediate edges ba + idx["trip_in_to_quad"] = adj_trip_in_per_trip_out.storage.value() + idx_in = idx["triplet_in"]["in"][idx["trip_in_to_quad"]] + + # Remove quadruplets with c == d + # Triplets should already ensure that a != d and b != c + # Compare atom indices and cell offsets + idx_atom_c = idx_s[idx["out"]] + idx_atom_d = idx_s[idx_in] + + cell_offset_cd = ( + main_graph["cell_offset"][idx_in] + + qint_graph["cell_offset"][idx_inter] + - main_graph["cell_offset"][idx["out"]] + ) + mask_cd = (idx_atom_c != idx_atom_d) | torch.any(cell_offset_cd != 0, dim=-1) + + idx["out"] = idx["out"][mask_cd] + idx["trip_out_to_quad"] = idx["trip_out_to_quad"][mask_cd] + idx["trip_in_to_quad"] = idx["trip_in_to_quad"][mask_cd] + + # idx['out'] has to be sorted for this + idx["out_agg"] = get_inner_idx(idx["out"], dim_size=num_edges) + + return idx diff --git a/src/jmp/models/gemnet/layers/atom_update_block.py b/src/jmp/models/gemnet/layers/atom_update_block.py new file mode 100644 index 0000000..e516a8f --- /dev/null +++ b/src/jmp/models/gemnet/layers/atom_update_block.py @@ -0,0 +1,234 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import math + +import torch +from torch_scatter import scatter + +from ....modules.scaling import ScaleFactor +from .base_layers import Dense, ResidualLayer + + +class AtomUpdateBlock(torch.nn.Module): + """ + Aggregate the message embeddings of the atoms + + Arguments + --------- + emb_size_atom: int + Embedding size of the atoms. + emb_size_edge: int + Embedding size of the edges. + emb_size_rbf: int + Embedding size of the radial basis. + nHidden: int + Number of residual blocks. + activation: callable/str + Name of the activation function to use in the dense layers. + """ + + def __init__( + self, + emb_size_atom: int, + emb_size_edge: int, + emb_size_rbf: int, + nHidden: int, + activation=None, + *, + dropout: float | None, + ): + super().__init__() + + self.dense_rbf = Dense( + emb_size_rbf, + emb_size_edge, + activation=None, + bias=False, + dropout=dropout, + ) + self.scale_sum = ScaleFactor() + + self.layers = self.get_mlp( + emb_size_edge, emb_size_atom, nHidden, activation, dropout + ) + + def get_mlp(self, units_in, units, nHidden, activation, dropout: float | None): + if units_in != units: + dense1 = Dense( + units_in, + units, + activation=activation, + bias=False, + dropout=dropout, + ) + mlp = [dense1] + else: + mlp = [] + res = [ + ResidualLayer(units, nLayers=2, activation=activation, dropout=dropout) + for i in range(nHidden) + ] + mlp += res + return torch.nn.ModuleList(mlp) + + def forward(self, h, m, basis_rad, idx_atom): + """ + Returns + ------- + h: torch.Tensor, shape=(nAtoms, emb_size_atom) + Atom embedding. + """ + nAtoms = h.shape[0] + + bases_emb = self.dense_rbf(basis_rad) # (nEdges, emb_size_edge) + x = m * bases_emb + + x2 = scatter( + x, idx_atom, dim=0, dim_size=nAtoms, reduce="sum" + ) # (nAtoms, emb_size_edge) + x = self.scale_sum(x2, ref=m) + + for layer in self.layers: + x = layer(x) # (nAtoms, emb_size_atom) + + return x + + +class OutputBlock(AtomUpdateBlock): + """ + Combines the atom update block and subsequent final dense layer. + + Arguments + --------- + emb_size_atom: int + Embedding size of the atoms. + emb_size_edge: int + Embedding size of the edges. + emb_size_rbf: int + Embedding size of the radial basis. + nHidden: int + Number of residual blocks before adding the atom embedding. + nHidden_afteratom: int + Number of residual blocks after adding the atom embedding. + activation: str + Name of the activation function to use in the dense layers. + direct_forces: bool + If true directly predict forces, i.e. without taking the gradient + of the energy potential. + """ + + def __init__( + self, + emb_size_atom: int, + emb_size_edge: int, + emb_size_rbf: int, + nHidden: int, + nHidden_afteratom: int, + activation=None, + direct_forces=True, + *, + edge_dropout: float | None, + dropout: float | None, + ): + super().__init__( + emb_size_atom=emb_size_atom, + emb_size_edge=emb_size_edge, + emb_size_rbf=emb_size_rbf, + nHidden=nHidden, + activation=activation, + dropout=dropout, + ) + + self.direct_forces = direct_forces + self.edge_dropout = edge_dropout + + self.seq_energy_pre = self.layers # inherited from parent class + if nHidden_afteratom >= 1: + self.seq_energy2 = self.get_mlp( + emb_size_atom, + emb_size_atom, + nHidden_afteratom, + activation, + dropout, + ) + self.inv_sqrt_2 = 1 / math.sqrt(2.0) + else: + self.seq_energy2 = None + + if self.direct_forces: + self.scale_rbf_F = ScaleFactor() + self.seq_forces = self.get_mlp( + emb_size_edge, + emb_size_edge, + nHidden, + activation, + dropout, + ) + self.dense_rbf_F = Dense( + emb_size_rbf, + emb_size_edge, + activation=None, + bias=False, + dropout=dropout, + ) + + def _drop_edge_boost_activations(self, x: torch.Tensor): + if not self.training or not self.edge_dropout: + return x + + x = x / (1 - self.edge_dropout) + return x + + def forward(self, h, m, basis_rad, idx_atom, *, data): + """ + Returns + ------- + torch.Tensor, shape=(nAtoms, emb_size_atom) + Output atom embeddings. + torch.Tensor, shape=(nEdges, emb_size_edge) + Output edge embeddings. + """ + nAtoms = h.shape[0] + + # ------------------------ Atom embeddings ------------------------ # + basis_emb_E = self.dense_rbf(basis_rad) # (nEdges, emb_size_edge) + x = m * basis_emb_E + + x_E = scatter( + x, idx_atom, dim=0, dim_size=nAtoms, reduce="sum" + ) # (nAtoms, emb_size_edge) + + x_E = self._drop_edge_boost_activations(x_E) + + x_E = self.scale_sum(x_E, ref=m) + + for layer in self.seq_energy_pre: + x_E = layer(x_E) # (nAtoms, emb_size_atom) + + if self.seq_energy2 is not None: + x_E = x_E + h + x_E = x_E * self.inv_sqrt_2 + for layer in self.seq_energy2: + x_E = layer(x_E) # (nAtoms, emb_size_atom) + + # ------------------------- Edge embeddings ------------------------ # + if self.direct_forces: + x_F = m + for i, layer in enumerate(self.seq_forces): + x_F = layer(x_F) # (nEdges, emb_size_edge) + + basis_emb_F = self.dense_rbf_F(basis_rad) + # (nEdges, emb_size_edge) + x_F_basis = x_F * basis_emb_F + x_F = self.scale_rbf_F(x_F_basis, ref=x_F) + else: + x_F = 0 + # ------------------------------------------------------------------ # + + return x_E, x_F diff --git a/src/jmp/models/gemnet/layers/base_layers.py b/src/jmp/models/gemnet/layers/base_layers.py new file mode 100644 index 0000000..1b2ac99 --- /dev/null +++ b/src/jmp/models/gemnet/layers/base_layers.py @@ -0,0 +1,129 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import math +from collections.abc import Callable + +import torch.nn as nn + +from ..initializers import he_orthogonal_init + + +class Dense(nn.Module): + """ + Combines dense layer with scaling for silu activation. + + Arguments + --------- + in_features: int + Input embedding size. + out_features: int + Output embedding size. + bias: bool + True if use bias. + activation: str + Name of the activation function to use. + """ + + def __init__( + self, + in_features, + out_features, + bias=False, + activation=None, + scale_dim: bool = False, + *, + dropout: float | None, + ): + super().__init__() + + self.scale_dim = scale_dim + + self.linear = nn.Linear(in_features, out_features, bias=bias) + self.reset_parameters() + + if isinstance(activation, str): + activation = activation.lower() + if activation in ["scaled_silu", "scaled_swish"]: + self.activation = ScaledSiLU() + elif activation in ["silu", "swish"]: + self.activation = nn.SiLU() + elif activation is None: + self.activation = nn.Identity() + else: + raise NotImplementedError( + "Activation function not implemented for GemNet (yet)." + ) + + self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity() + + def reset_parameters( + self, + initializer=he_orthogonal_init, + ): + initializer(self.linear.weight) + if self.linear.bias is not None: + _ = self.linear.bias.data.fill_(0) + + def forward(self, x): + x = self.linear(x) + x = self.activation(x) + x = self.dropout(x) + if self.scale_dim: + x = x * (self.linear.weight.shape[1] ** -0.5) + return x + + +class ScaledSiLU(nn.Module): + def __init__(self): + super().__init__() + self.scale_factor = 1 / 0.6 + self._activation = nn.SiLU() + + def forward(self, x): + return self._activation(x) * self.scale_factor + + +class ResidualLayer(nn.Module): + """ + Residual block with output scaled by 1/sqrt(2). + + Arguments + --------- + units: int + Input and output embedding size. + nLayers: int + Number of dense layers. + layer: nn.Module + Class for the layers inside the residual block. + layer_kwargs: str + Keyword arguments for initializing the layers. + """ + + def __init__( + self, + units: int, + nLayers: int = 2, + layer: Callable[..., nn.Module] = Dense, + **layer_kwargs, + ): + super().__init__() + + self.dense_mlp = nn.Sequential( + *[ + layer(in_features=units, out_features=units, bias=False, **layer_kwargs) + for _ in range(nLayers) + ] + ) + self.inv_sqrt_2 = 1 / math.sqrt(2) + + def forward(self, input): + x = self.dense_mlp(input) + x = input + x + x = x * self.inv_sqrt_2 + return x diff --git a/src/jmp/models/gemnet/layers/basis_utils.py b/src/jmp/models/gemnet/layers/basis_utils.py new file mode 100644 index 0000000..5c091ab --- /dev/null +++ b/src/jmp/models/gemnet/layers/basis_utils.py @@ -0,0 +1,321 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import sympy as sym +import torch +from scipy import special as sp +from scipy.optimize import brentq + + +def Jn(r, n): + """ + numerical spherical bessel functions of order n + """ + return sp.spherical_jn(n, r) + + +def Jn_zeros(n, k): + """ + Compute the first k zeros of the spherical bessel functions + up to order n (excluded) + """ + zerosj = np.zeros((n, k), dtype="float32") + zerosj[0] = np.arange(1, k + 1) * np.pi + points = np.arange(1, k + n) * np.pi + racines = np.zeros(k + n - 1, dtype="float32") + for i in range(1, n): + for j in range(k + n - 1 - i): + foo = brentq(Jn, points[j], points[j + 1], (i,)) + racines[j] = foo + points = racines + zerosj[i][:k] = racines[:k] + + return zerosj + + +def spherical_bessel_formulas(n): + """ + Computes the sympy formulas for the spherical bessel functions + up to order n (excluded) + """ + x = sym.symbols("x", real=True) + # j_i = (-x)^i * (1/x * d/dx)^î * sin(x)/x + j = [sym.sin(x) / x] # j_0 + a = sym.sin(x) / x + for i in range(1, n): + b = sym.diff(a, x) / x + j += [sym.simplify(b * (-x) ** i)] + a = sym.simplify(b) + return j + + +def bessel_basis(n, k): + """ + Compute the sympy formulas for the normalized and rescaled spherical bessel + functions up to order n (excluded) and maximum frequency k (excluded). + + Returns + ------- + bess_basis: list + Bessel basis formulas taking in a single argument x. + Has length n where each element has length k. -> In total n*k many. + """ + zeros = Jn_zeros(n, k) + normalizer = [] + for order in range(n): + normalizer_tmp = [] + for i in range(k): + normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1) ** 2] + normalizer_tmp = ( + 1 / np.array(normalizer_tmp) ** 0.5 + ) # sqrt(2/(j_l+1)**2) , sqrt(1/c**3) not taken into account yet + normalizer += [normalizer_tmp] + + f = spherical_bessel_formulas(n) + x = sym.symbols("x", real=True) + bess_basis = [] + for order in range(n): + bess_basis_tmp = [] + for i in range(k): + bess_basis_tmp += [ + sym.simplify( + normalizer[order][i] * f[order].subs(x, zeros[order, i] * x) + ) + ] + bess_basis += [bess_basis_tmp] + return bess_basis + + +def sph_harm_prefactor(l_degree, m_order): + """ + Computes the constant pre-factor for the spherical harmonic + of degree l and order m. + + Arguments + --------- + l_degree: int + Degree of the spherical harmonic. l >= 0 + m_order: int + Order of the spherical harmonic. -l <= m <= l + + Returns + ------- + factor: float + + """ + # sqrt((2*l+1)/4*pi * (l-m)!/(l+m)! ) + return ( + (2 * l_degree + 1) + / (4 * np.pi) + * np.math.factorial(l_degree - abs(m_order)) + / np.math.factorial(l_degree + abs(m_order)) + ) ** 0.5 + + +def associated_legendre_polynomials(L_maxdegree, zero_m_only=True, pos_m_only=True): + """ + Computes string formulas of the associated legendre polynomials + up to degree L (excluded). + + Arguments + --------- + L_maxdegree: int + Degree up to which to calculate the associated legendre polynomials + (degree L is excluded). + zero_m_only: bool + If True only calculate the polynomials for the polynomials where m=0. + pos_m_only: bool + If True only calculate the polynomials for the polynomials where m>=0. + Overwritten by zero_m_only. + + Returns + ------- + polynomials: list + Contains the sympy functions of the polynomials + (in total L many if zero_m_only is True else L^2 many). + """ + # calculations from http://web.cmb.usc.edu/people/alber/Software/tomominer/docs/cpp/group__legendre__polynomials.html + z = sym.symbols("z", real=True) + P_l_m = [ + [0] * (2 * l_degree + 1) for l_degree in range(L_maxdegree) + ] # for order l: -l <= m <= l + + P_l_m[0][0] = 1 + if L_maxdegree > 1: + if zero_m_only: + # m = 0 + P_l_m[1][0] = z + for l_degree in range(2, L_maxdegree): + P_l_m[l_degree][0] = sym.simplify( + ( + (2 * l_degree - 1) * z * P_l_m[l_degree - 1][0] + - (l_degree - 1) * P_l_m[l_degree - 2][0] + ) + / l_degree + ) + return P_l_m + else: + # for m >= 0 + for l_degree in range(1, L_maxdegree): + P_l_m[l_degree][l_degree] = sym.simplify( + (1 - 2 * l_degree) + * (1 - z**2) ** 0.5 + * P_l_m[l_degree - 1][l_degree - 1] + ) # P_00, P_11, P_22, P_33 + + for m_order in range(0, L_maxdegree - 1): + P_l_m[m_order + 1][m_order] = sym.simplify( + (2 * m_order + 1) * z * P_l_m[m_order][m_order] + ) # P_10, P_21, P_32, P_43 + + for l_degree in range(2, L_maxdegree): + for m_order in range(l_degree - 1): # P_20, P_30, P_31 + P_l_m[l_degree][m_order] = sym.simplify( + ( + (2 * l_degree - 1) * z * P_l_m[l_degree - 1][m_order] + - (l_degree + m_order - 1) * P_l_m[l_degree - 2][m_order] + ) + / (l_degree - m_order) + ) + + if not pos_m_only: + # for m < 0: P_l(-m) = (-1)^m * (l-m)!/(l+m)! * P_lm + for l_degree in range(1, L_maxdegree): + for m_order in range(1, l_degree + 1): # P_1(-1), P_2(-1) P_2(-2) + P_l_m[l_degree][-m_order] = sym.simplify( + (-1) ** m_order + * np.math.factorial(l_degree - m_order) + / np.math.factorial(l_degree + m_order) + * P_l_m[l_degree][m_order] + ) + + return P_l_m + + +def real_sph_harm(L_maxdegree, use_theta, use_phi=True, zero_m_only=True): + """ + Computes formula strings of the the real part of the spherical harmonics + up to degree L (excluded). Variables are either spherical coordinates phi + and theta (or cartesian coordinates x,y,z) on the UNIT SPHERE. + + Arguments + --------- + L_maxdegree: int + Degree up to which to calculate the spherical harmonics + (degree L is excluded). + use_theta: bool + - True: Expects the input of the formula strings to contain theta. + - False: Expects the input of the formula strings to contain z. + use_phi: bool + - True: Expects the input of the formula strings to contain phi. + - False: Expects the input of the formula strings to contain x and y. + Does nothing if zero_m_only is True + zero_m_only: bool + If True only calculate the harmonics where m=0. + + Returns + ------- + Y_lm_real: list + Computes formula strings of the the real part of the spherical + harmonics up to degree L (where degree L is not excluded). + In total L^2 many sph harm exist up to degree L (excluded). + However, if zero_m_only only is True then the total count + is reduced to L. + """ + z = sym.symbols("z", real=True) + P_l_m = associated_legendre_polynomials(L_maxdegree, zero_m_only) + if zero_m_only: + # for all m != 0: Y_lm = 0 + Y_l_m = [[0] for l_degree in range(L_maxdegree)] + else: + Y_l_m = [ + [0] * (2 * l_degree + 1) for l_degree in range(L_maxdegree) + ] # for order l: -l <= m <= l + + # convert expressions to spherical coordiantes + if use_theta: + # replace z by cos(theta) + theta = sym.symbols("theta", real=True) + for l_degree in range(L_maxdegree): + for m_order in range(len(P_l_m[l_degree])): + if not isinstance(P_l_m[l_degree][m_order], int): + P_l_m[l_degree][m_order] = P_l_m[l_degree][m_order].subs( + z, sym.cos(theta) + ) + + ## calculate Y_lm + # Y_lm = N * P_lm(cos(theta)) * exp(i*m*phi) + # { sqrt(2) * (-1)^m * N * P_l|m| * sin(|m|*phi) if m < 0 + # Y_lm_real = { Y_lm if m = 0 + # { sqrt(2) * (-1)^m * N * P_lm * cos(m*phi) if m > 0 + + for l_degree in range(L_maxdegree): + Y_l_m[l_degree][0] = sym.simplify( + sph_harm_prefactor(l_degree, 0) * P_l_m[l_degree][0] + ) # Y_l0 + + if not zero_m_only: + phi = sym.symbols("phi", real=True) + for l_degree in range(1, L_maxdegree): + # m > 0 + for m_order in range(1, l_degree + 1): + Y_l_m[l_degree][m_order] = sym.simplify( + 2**0.5 + * (-1) ** m_order + * sph_harm_prefactor(l_degree, m_order) + * P_l_m[l_degree][m_order] + * sym.cos(m_order * phi) + ) + # m < 0 + for m_order in range(1, l_degree + 1): + Y_l_m[l_degree][-m_order] = sym.simplify( + 2**0.5 + * (-1) ** m_order + * sph_harm_prefactor(l_degree, -m_order) + * P_l_m[l_degree][m_order] + * sym.sin(m_order * phi) + ) + + # convert expressions to cartesian coordinates + if not use_phi: + # replace phi by atan2(y,x) + x, y = sym.symbols("x y", real=True) + for l_degree in range(L_maxdegree): + for m_order in range(len(Y_l_m[l_degree])): + Y_l_m[l_degree][m_order] = sym.simplify( + Y_l_m[l_degree][m_order].subs(phi, sym.atan2(y, x)) + ) + return Y_l_m + + +def get_sph_harm_basis(L_maxdegree, zero_m_only=True): + """Get a function calculating the spherical harmonics basis from z and phi.""" + # retrieve equations + Y_lm = real_sph_harm( + L_maxdegree, use_theta=False, use_phi=True, zero_m_only=zero_m_only + ) + Y_lm_flat = [Y for Y_l in Y_lm for Y in Y_l] + + # convert to pytorch functions + z = sym.symbols("z", real=True) + variables = [z] + if not zero_m_only: + variables.append(sym.symbols("phi", real=True)) + + modules = {"sin": torch.sin, "cos": torch.cos, "sqrt": torch.sqrt} + sph_funcs = sym.lambdify(variables, Y_lm_flat, modules) + + # Return as a single function + # args are either [cosφ] or [cosφ, ϑ] + def basis_fn(*args): + basis = sph_funcs(*args) + basis[0] = args[0].new_tensor(basis[0]).expand_as(args[0]) + return torch.stack(basis, dim=1) + + return basis_fn diff --git a/src/jmp/models/gemnet/layers/efficient.py b/src/jmp/models/gemnet/layers/efficient.py new file mode 100644 index 0000000..d5ac812 --- /dev/null +++ b/src/jmp/models/gemnet/layers/efficient.py @@ -0,0 +1,265 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Optional + +import torch + +from ..initializers import he_orthogonal_init +from .base_layers import Dense + + +class BasisEmbedding(torch.nn.Module): + """ + Embed a basis (CBF, SBF), optionally using the efficient reformulation. + + Arguments + --------- + num_radial: int + Number of radial basis functions. + emb_size_interm: int + Intermediate embedding size of triplets/quadruplets. + num_spherical: int + Number of circular/spherical basis functions. + Only required if there is a circular/spherical basis. + """ + + def __init__( + self, + num_radial: int, + emb_size_interm: int, + num_spherical: Optional[int] = None, + ): + super().__init__() + self.num_radial = num_radial + self.num_spherical = num_spherical + if num_spherical is None: + self.weight = torch.nn.Parameter( + torch.empty(emb_size_interm, num_radial), + requires_grad=True, + ) + else: + self.weight = torch.nn.Parameter( + torch.empty(num_radial, num_spherical, emb_size_interm), + requires_grad=True, + ) + self.reset_parameters() + + def reset_parameters(self): + he_orthogonal_init(self.weight) + + def forward( + self, + rad_basis, + sph_basis=None, + idx_rad_outer=None, + idx_rad_inner=None, + idx_sph_outer=None, + idx_sph_inner=None, + num_atoms=None, + ): + """ + + Arguments + --------- + rad_basis: torch.Tensor, shape=(num_edges, num_radial or num_orders * num_radial) + Raw radial basis. + sph_basis: torch.Tensor, shape=(num_triplets or num_quadruplets, num_spherical) + Raw spherical or circular basis. + idx_rad_outer: torch.Tensor, shape=(num_edges) + Atom associated with each radial basis value. + Optional, used for efficient edge aggregation. + idx_rad_inner: torch.Tensor, shape=(num_edges) + Enumerates radial basis values per atom. + Optional, used for efficient edge aggregation. + idx_sph_outer: torch.Tensor, shape=(num_triplets or num_quadruplets) + Edge associated with each circular/spherical basis value. + Optional, used for efficient triplet/quadruplet aggregation. + idx_sph_inner: torch.Tensor, shape=(num_triplets or num_quadruplets) + Enumerates circular/spherical basis values per edge. + Optional, used for efficient triplet/quadruplet aggregation. + num_atoms: int + Total number of atoms. + Optional, used for efficient edge aggregation. + + Returns + ------- + rad_W1: torch.Tensor, shape=(num_edges, emb_size_interm, num_spherical) + sph: torch.Tensor, shape=(num_edges, Kmax, num_spherical) + Kmax = maximum number of neighbors of the edges + """ + num_edges = rad_basis.shape[0] + + if self.num_spherical is not None: + # MatMul: mul + sum over num_radial + rad_W1 = rad_basis @ self.weight.reshape(self.weight.shape[0], -1) + # (num_edges, emb_size_interm * num_spherical) + rad_W1 = rad_W1.reshape(num_edges, -1, sph_basis.shape[-1]) + # (num_edges, emb_size_interm, num_spherical) + else: + # MatMul: mul + sum over num_radial + rad_W1 = rad_basis @ self.weight.T + # (num_edges, emb_size_interm) + + if idx_rad_inner is not None: + # Zero padded dense matrix + # maximum number of neighbors + if idx_rad_outer.shape[0] == 0: + # catch empty idx_rad_outer + Kmax = 0 + else: + Kmax = torch.max(idx_rad_inner) + 1 + + rad_W1_padded = rad_W1.new_zeros([num_atoms, Kmax] + list(rad_W1.shape[1:])) + rad_W1_padded[idx_rad_outer, idx_rad_inner] = rad_W1 + # (num_atoms, Kmax, emb_size_interm, ...) + rad_W1_padded = torch.transpose(rad_W1_padded, 1, 2) + # (num_atoms, emb_size_interm, Kmax, ...) + rad_W1_padded = rad_W1_padded.reshape(num_atoms, rad_W1.shape[1], -1) + # (num_atoms, emb_size_interm, Kmax2 * ...) + rad_W1 = rad_W1_padded + + if idx_sph_inner is not None: + # Zero padded dense matrix + # maximum number of neighbors + if idx_sph_outer.shape[0] == 0: + # catch empty idx_sph_outer + Kmax = 0 + else: + Kmax = torch.max(idx_sph_inner) + 1 + + sph2 = sph_basis.new_zeros(num_edges, Kmax, sph_basis.shape[-1]) + sph2[idx_sph_outer, idx_sph_inner] = sph_basis + # (num_edges, Kmax, num_spherical) + sph2 = torch.transpose(sph2, 1, 2) + # (num_edges, num_spherical, Kmax) + + if sph_basis is None: + return rad_W1 + else: + if idx_sph_inner is None: + rad_W1 = rad_W1[idx_sph_outer] + # (num_triplets, emb_size_interm, num_spherical) + + sph_W1 = rad_W1 @ sph_basis[:, :, None] + # (num_triplets, emb_size_interm, num_spherical) + return sph_W1.squeeze(-1) + else: + return rad_W1, sph2 + + +class EfficientInteractionBilinear(torch.nn.Module): + """ + Efficient reformulation of the bilinear layer and subsequent summation. + + Arguments + --------- + emb_size_in: int + Embedding size of input triplets/quadruplets. + emb_size_interm: int + Intermediate embedding size of the basis transformation. + emb_size_out: int + Embedding size of output triplets/quadruplets. + """ + + def __init__( + self, + emb_size_in: int, + emb_size_interm: int, + emb_size_out: int, + *, + dropout: float | None, + ): + super().__init__() + self.emb_size_in = emb_size_in + self.emb_size_interm = emb_size_interm + self.emb_size_out = emb_size_out + + self.bilinear = Dense( + self.emb_size_in * self.emb_size_interm, + self.emb_size_out, + bias=False, + activation=None, + dropout=dropout, + ) + + def forward( + self, + basis, + m, + idx_agg_outer, + idx_agg_inner, + idx_agg2_outer=None, + idx_agg2_inner=None, + agg2_out_size=None, + ): + """ + + Arguments + --------- + basis: Tuple (torch.Tensor, torch.Tensor), + shapes=((num_edges, emb_size_interm, num_spherical), + (num_edges, num_spherical, Kmax)) + First element: Radial basis multiplied with weight matrix + Second element: Circular/spherical basis + m: torch.Tensor, shape=(num_edges, emb_size_in) + Input edge embeddings + idx_agg_outer: torch.Tensor, shape=(num_triplets or num_quadruplets) + Output edge aggregating this intermediate triplet/quadruplet edge. + idx_agg_inner: torch.Tensor, shape=(num_triplets or num_quadruplets) + Enumerates intermediate edges per output edge. + idx_agg2_outer: torch.Tensor, shape=(num_edges) + Output atom aggregating this edge. + idx_agg2_inner: torch.Tensor, shape=(num_edges) + Enumerates edges per output atom. + agg2_out_size: int + Number of output embeddings when aggregating twice. Typically + the number of atoms. + + Returns + ------- + m_ca: torch.Tensor, shape=(num_edges, emb_size) + Aggregated edge/atom embeddings. + """ + # num_spherical is actually num_spherical**2 for quadruplets + (rad_W1, sph) = basis + # (num_edges, emb_size_interm, num_spherical), + # (num_edges, num_spherical, Kmax) + num_edges = sph.shape[0] + + # Create (zero-padded) dense matrix of the neighboring edge embeddings. + Kmax = torch.max(idx_agg_inner) + 1 + m_padded = m.new_zeros(num_edges, Kmax, self.emb_size_in) + m_padded[idx_agg_outer, idx_agg_inner] = m + # (num_quadruplets/num_triplets, emb_size_in) -> (num_edges, Kmax, emb_size_in) + + sph_m = torch.matmul(sph, m_padded) + # (num_edges, num_spherical, emb_size_in) + + if idx_agg2_outer is not None: + Kmax2 = torch.max(idx_agg2_inner) + 1 + sph_m_padded = sph_m.new_zeros( + agg2_out_size, Kmax2, sph_m.shape[1], sph_m.shape[2] + ) + sph_m_padded[idx_agg2_outer, idx_agg2_inner] = sph_m + # (num_atoms, Kmax2, num_spherical, emb_size_in) + sph_m_padded = sph_m_padded.reshape(agg2_out_size, -1, sph_m.shape[-1]) + # (num_atoms, Kmax2 * num_spherical, emb_size_in) + + rad_W1_sph_m = rad_W1 @ sph_m_padded + # (num_atoms, emb_size_interm, emb_size_in) + else: + # MatMul: mul + sum over num_spherical + rad_W1_sph_m = torch.matmul(rad_W1, sph_m) + # (num_edges, emb_size_interm, emb_size_in) + + # Bilinear: Sum over emb_size_interm and emb_size_in + m_ca = self.bilinear(rad_W1_sph_m.reshape(-1, rad_W1_sph_m.shape[1:].numel())) + # (num_edges/num_atoms, emb_size_out) + + return m_ca diff --git a/src/jmp/models/gemnet/layers/embedding_block.py b/src/jmp/models/gemnet/layers/embedding_block.py new file mode 100644 index 0000000..bca7b57 --- /dev/null +++ b/src/jmp/models/gemnet/layers/embedding_block.py @@ -0,0 +1,105 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch + +from .base_layers import Dense + + +class AtomEmbedding(torch.nn.Module): + """ + Initial atom embeddings based on the atom type + + Arguments + --------- + emb_size: int + Atom embeddings size + """ + + def __init__(self, emb_size, num_elements): + super().__init__() + self.emb_size = emb_size + + self.embeddings = torch.nn.Embedding(num_elements, emb_size) + # init by uniform distribution + torch.nn.init.uniform_(self.embeddings.weight, a=-np.sqrt(3), b=np.sqrt(3)) + + def forward(self, Z): + """ + Returns + ------- + h: torch.Tensor, shape=(nAtoms, emb_size) + Atom embeddings. + """ + h = self.embeddings(Z - 1) # -1 because Z.min()=1 (==Hydrogen) + return h + + +class EdgeEmbedding(torch.nn.Module): + """ + Edge embedding based on the concatenation of atom embeddings + and a subsequent dense layer. + + Arguments + --------- + atom_features: int + Embedding size of the atom embedding. + edge_features: int + Embedding size of the input edge embedding. + out_features: int + Embedding size after the dense layer. + activation: str + Activation function used in the dense layer. + """ + + def __init__( + self, + atom_features, + edge_features, + out_features, + activation=None, + *, + dropout: float | None, + ): + super().__init__() + in_features = 2 * atom_features + edge_features + self.dense = Dense( + in_features, + out_features, + activation=activation, + bias=False, + dropout=dropout, + ) + + def forward( + self, + h, + m, + edge_index, + ): + """ + Arguments + --------- + h: torch.Tensor, shape (num_atoms, atom_features) + Atom embeddings. + m: torch.Tensor, shape (num_edges, edge_features) + Radial basis in embedding block, + edge embedding in interaction block. + + Returns + ------- + m_st: torch.Tensor, shape=(nEdges, emb_size) + Edge embeddings. + """ + h_s = h[edge_index[0]] # shape=(nEdges, emb_size) + h_t = h[edge_index[1]] # shape=(nEdges, emb_size) + + m_st = torch.cat([h_s, h_t, m], dim=-1) # (nEdges, 2*emb_size+nFeatures) + m_st = self.dense(m_st) # (nEdges, emb_size) + return m_st diff --git a/src/jmp/models/gemnet/layers/force_scaler.py b/src/jmp/models/gemnet/layers/force_scaler.py new file mode 100644 index 0000000..f72cf3a --- /dev/null +++ b/src/jmp/models/gemnet/layers/force_scaler.py @@ -0,0 +1,96 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import logging + +import torch + + +class ForceScaler: + """ + Scales up the energy and then scales down the forces + to prevent NaNs and infs in calculations using AMP. + Inspired by torch.cuda.amp.GradScaler. + """ + + def __init__( + self, + init_scale=2.0**8, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + max_force_iters=50, + enabled=True, + ): + self.scale_factor = init_scale + self.growth_factor = growth_factor + self.backoff_factor = backoff_factor + self.growth_interval = growth_interval + self.max_force_iters = max_force_iters + self.enabled = enabled + self.finite_force_results = 0 + + def scale(self, energy): + return energy * self.scale_factor if self.enabled else energy + + def unscale(self, forces): + return forces / self.scale_factor if self.enabled else forces + + def calc_forces(self, energy, pos): + energy_scaled = self.scale(energy) + forces_scaled = -torch.autograd.grad( + energy_scaled, + pos, + grad_outputs=torch.ones_like(energy_scaled), + create_graph=True, + )[0] + # (nAtoms, 3) + forces = self.unscale(forces_scaled) + return forces + + def calc_forces_and_update(self, energy, pos): + if self.enabled: + found_nans_or_infs = True + force_iters = 0 + + # Re-calculate forces until everything is nice and finite. + while found_nans_or_infs: + forces = self.calc_forces(energy, pos) + + found_nans_or_infs = not torch.all(forces.isfinite()) + if found_nans_or_infs: + self.finite_force_results = 0 + + # Prevent infinite loop + force_iters += 1 + if force_iters == self.max_force_iters: + logging.warning( + "Too many non-finite force results in a batch. " + "Breaking scaling loop." + ) + break + else: + # Delete graph to save memory + del forces + else: + self.finite_force_results += 1 + self.update() + else: + forces = self.calc_forces(energy, pos) + return forces + + def update(self): + if self.finite_force_results == 0: + self.scale_factor *= self.backoff_factor + + if self.finite_force_results == self.growth_interval: + self.scale_factor *= self.growth_factor + self.finite_force_results = 0 + + logging.info(f"finite force step count: {self.finite_force_results}") + logging.info(f"scaling factor: {self.scale_factor}") diff --git a/src/jmp/models/gemnet/layers/interaction_block.py b/src/jmp/models/gemnet/layers/interaction_block.py new file mode 100644 index 0000000..c708611 --- /dev/null +++ b/src/jmp/models/gemnet/layers/interaction_block.py @@ -0,0 +1,786 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import math + +import torch + +from ....modules.scaling import ScaleFactor +from .atom_update_block import AtomUpdateBlock +from .base_layers import Dense, ResidualLayer +from .efficient import EfficientInteractionBilinear +from .embedding_block import EdgeEmbedding + + +class InteractionBlock(torch.nn.Module): + """ + Interaction block for GemNet-Q/dQ. + + Arguments + --------- + emb_size_atom: int + Embedding size of the atoms. + emb_size_edge: int + Embedding size of the edges. + emb_size_trip_in: int + (Down-projected) embedding size of the quadruplet edge embeddings + before the bilinear layer. + emb_size_trip_out: int + (Down-projected) embedding size of the quadruplet edge embeddings + after the bilinear layer. + emb_size_quad_in: int + (Down-projected) embedding size of the quadruplet edge embeddings + before the bilinear layer. + emb_size_quad_out: int + (Down-projected) embedding size of the quadruplet edge embeddings + after the bilinear layer. + emb_size_a2a_in: int + Embedding size in the atom interaction before the bilinear layer. + emb_size_a2a_out: int + Embedding size in the atom interaction after the bilinear layer. + emb_size_rbf: int + Embedding size of the radial basis transformation. + emb_size_cbf: int + Embedding size of the circular basis transformation (one angle). + emb_size_sbf: int + Embedding size of the spherical basis transformation (two angles). + num_before_skip: int + Number of residual blocks before the first skip connection. + num_after_skip: int + Number of residual blocks after the first skip connection. + num_concat: int + Number of residual blocks after the concatenation. + num_atom: int + Number of residual blocks in the atom embedding blocks. + num_atom_emb_layers: int + Number of residual blocks for transforming atom embeddings. + quad_interaction: bool + Whether to use quadruplet interactions. + atom_edge_interaction: bool + Whether to use atom-to-edge interactions. + edge_atom_interaction: bool + Whether to use edge-to-atom interactions. + atom_interaction: bool + Whether to use atom-to-atom interactions. + activation: str + Name of the activation function to use in the dense layers. + """ + + def __init__( + self, + emb_size_atom, + emb_size_edge, + emb_size_trip_in, + emb_size_trip_out, + emb_size_quad_in, + emb_size_quad_out, + emb_size_a2a_in, + emb_size_a2a_out, + emb_size_rbf, + emb_size_cbf, + emb_size_sbf, + num_before_skip, + num_after_skip, + num_concat, + num_atom, + num_atom_emb_layers=0, + quad_interaction=False, + atom_edge_interaction=False, + edge_atom_interaction=False, + atom_interaction=False, + activation=None, + *, + dropout: float | None, + ): + super().__init__() + + ## ------------------------ Message Passing ----------------------- ## + # Dense transformation of skip connection + self.dense_ca = Dense( + emb_size_edge, + emb_size_edge, + activation=activation, + bias=False, + dropout=dropout, + ) + + # Triplet Interaction + self.trip_interaction = TripletInteraction( + emb_size_in=emb_size_edge, + emb_size_out=emb_size_edge, + emb_size_trip_in=emb_size_trip_in, + emb_size_trip_out=emb_size_trip_out, + emb_size_rbf=emb_size_rbf, + emb_size_cbf=emb_size_cbf, + symmetric_mp=True, + swap_output=True, + activation=activation, + dropout=dropout, + ) + + # Quadruplet Interaction + if quad_interaction: + self.quad_interaction = QuadrupletInteraction( + emb_size_edge=emb_size_edge, + emb_size_quad_in=emb_size_quad_in, + emb_size_quad_out=emb_size_quad_out, + emb_size_rbf=emb_size_rbf, + emb_size_cbf=emb_size_cbf, + emb_size_sbf=emb_size_sbf, + symmetric_mp=True, + activation=activation, + dropout=dropout, + ) + else: + self.quad_interaction = None + + if atom_edge_interaction: + self.atom_edge_interaction = TripletInteraction( + emb_size_in=emb_size_atom, + emb_size_out=emb_size_edge, + emb_size_trip_in=emb_size_trip_in, + emb_size_trip_out=emb_size_trip_out, + emb_size_rbf=emb_size_rbf, + emb_size_cbf=emb_size_cbf, + symmetric_mp=True, + swap_output=True, + activation=activation, + dropout=dropout, + ) + else: + self.atom_edge_interaction = None + if edge_atom_interaction: + self.edge_atom_interaction = TripletInteraction( + emb_size_in=emb_size_edge, + emb_size_out=emb_size_atom, + emb_size_trip_in=emb_size_trip_in, + emb_size_trip_out=emb_size_trip_out, + emb_size_rbf=emb_size_rbf, + emb_size_cbf=emb_size_cbf, + symmetric_mp=False, + swap_output=False, + activation=activation, + dropout=dropout, + ) + else: + self.edge_atom_interaction = None + if atom_interaction: + self.atom_interaction = PairInteraction( + emb_size_atom=emb_size_atom, + emb_size_pair_in=emb_size_a2a_in, + emb_size_pair_out=emb_size_a2a_out, + emb_size_rbf=emb_size_rbf, + activation=activation, + dropout=dropout, + ) + else: + self.atom_interaction = None + + ## -------------------- Update Edge Embeddings -------------------- ## + # Residual layers before skip connection + self.layers_before_skip = torch.nn.ModuleList( + [ + ResidualLayer(emb_size_edge, activation=activation, dropout=dropout) + for i in range(num_before_skip) + ] + ) + + # Residual layers after skip connection + self.layers_after_skip = torch.nn.ModuleList( + [ + ResidualLayer(emb_size_edge, activation=activation, dropout=dropout) + for i in range(num_after_skip) + ] + ) + + ## -------------------- Update Atom Embeddings -------------------- ## + self.atom_emb_layers = torch.nn.ModuleList( + [ + ResidualLayer(emb_size_atom, activation=activation, dropout=dropout) + for _ in range(num_atom_emb_layers) + ] + ) + + self.atom_update = AtomUpdateBlock( + emb_size_atom=emb_size_atom, + emb_size_edge=emb_size_edge, + emb_size_rbf=emb_size_rbf, + nHidden=num_atom, + activation=activation, + dropout=dropout, + ) + + ## ---------- Update Edge Embeddings with Atom Embeddings --------- ## + self.concat_layer = EdgeEmbedding( + emb_size_atom, + emb_size_edge, + emb_size_edge, + activation=activation, + dropout=dropout, + ) + self.residual_m = torch.nn.ModuleList( + [ + ResidualLayer(emb_size_edge, activation=activation, dropout=dropout) + for _ in range(num_concat) + ] + ) + + self.inv_sqrt_2 = 1 / math.sqrt(2.0) + num_eint = 2.0 + quad_interaction + atom_edge_interaction + self.inv_sqrt_num_eint = 1 / math.sqrt(num_eint) + num_aint = 1.0 + edge_atom_interaction + atom_interaction + self.inv_sqrt_num_aint = 1 / math.sqrt(num_aint) + + def forward( + self, + h, + m, + bases_qint, + bases_e2e, + bases_a2e, + bases_e2a, + basis_a2a_rad, + basis_atom_update, + edge_index_main, + a2ee2a_graph, + a2a_graph, + id_swap, + trip_idx_e2e, + trip_idx_a2e, + trip_idx_e2a, + quad_idx, + ): + """ + Returns + ------- + h: torch.Tensor, shape=(nEdges, emb_size_atom) + Atom embeddings. + m: torch.Tensor, shape=(nEdges, emb_size_edge) + Edge embeddings (c->a). + """ + num_atoms = h.shape[0] + + # Initial transformation + x_ca_skip = self.dense_ca(m) # (nEdges, emb_size_edge) + + x_e2e = self.trip_interaction( + m, + bases_e2e, + trip_idx_e2e, + id_swap, + ) + if self.quad_interaction is not None: + x_qint = self.quad_interaction( + m, + bases_qint, + quad_idx, + id_swap, + ) + if self.atom_edge_interaction is not None: + x_a2e = self.atom_edge_interaction( + h, + bases_a2e, + trip_idx_a2e, + id_swap, + expand_idx=a2ee2a_graph["edge_index"][0], + ) + if self.edge_atom_interaction is not None: + h_e2a = self.edge_atom_interaction( + m, + bases_e2a, + trip_idx_e2a, + id_swap, + idx_agg2=a2ee2a_graph["edge_index"][1], + idx_agg2_inner=a2ee2a_graph["target_neighbor_idx"], + agg2_out_size=num_atoms, + ) + if self.atom_interaction is not None: + h_a2a = self.atom_interaction( + h, + basis_a2a_rad, + a2a_graph["edge_index"], + a2a_graph["target_neighbor_idx"], + ) + + ## -------------- Merge Embeddings after interactions ------------- ## + x = x_ca_skip + x_e2e # (nEdges, emb_size_edge) + if self.quad_interaction is not None: + x += x_qint # (nEdges, emb_size_edge) + if self.atom_edge_interaction is not None: + x += x_a2e # (nEdges, emb_size_edge) + x = x * self.inv_sqrt_num_eint + + # Merge atom embeddings after interactions + if self.edge_atom_interaction is not None: + h = h + h_e2a # (nEdges, emb_size_edge) + if self.atom_interaction is not None: + h = h + h_a2a # (nEdges, emb_size_edge) + h = h * self.inv_sqrt_num_aint + + ## -------------------- Update Edge Embeddings -------------------- ## + # Transformations before skip connection + for i, layer in enumerate(self.layers_before_skip): + x = layer(x) # (nEdges, emb_size_edge) + + # Skip connection + m = m + x # (nEdges, emb_size_edge) + m = m * self.inv_sqrt_2 + + # Transformations after skip connection + for i, layer in enumerate(self.layers_after_skip): + m = layer(m) # (nEdges, emb_size_edge) + + ## -------------------- Update Atom Embeddings -------------------- ## + for layer in self.atom_emb_layers: + h = layer(h) # (nAtoms, emb_size_atom) + + h2 = self.atom_update(h, m, basis_atom_update, edge_index_main[1]) + + # Skip connection + h = h + h2 # (nAtoms, emb_size_atom) + h = h * self.inv_sqrt_2 + + ## ---------- Update Edge Embeddings with Atom Embeddings --------- ## + m2 = self.concat_layer(h, m, edge_index_main) + # (nEdges, emb_size_edge) + + for i, layer in enumerate(self.residual_m): + m2 = layer(m2) # (nEdges, emb_size_edge) + + # Skip connection + m = m + m2 # (nEdges, emb_size_edge) + m = m * self.inv_sqrt_2 + return h, m + + +class QuadrupletInteraction(torch.nn.Module): + """ + Quadruplet-based message passing block. + + Arguments + --------- + emb_size_edge: int + Embedding size of the edges. + emb_size_quad_in: int + (Down-projected) embedding size of the quadruplet edge embeddings + before the bilinear layer. + emb_size_quad_out: int + (Down-projected) embedding size of the quadruplet edge embeddings + after the bilinear layer. + emb_size_rbf: int + Embedding size of the radial basis transformation. + emb_size_cbf: int + Embedding size of the circular basis transformation (one angle). + emb_size_sbf: int + Embedding size of the spherical basis transformation (two angles). + symmetric_mp: bool + Whether to use symmetric message passing and + update the edges in both directions. + activation: str + Name of the activation function to use in the dense layers. + """ + + def __init__( + self, + emb_size_edge, + emb_size_quad_in, + emb_size_quad_out, + emb_size_rbf, + emb_size_cbf, + emb_size_sbf, + symmetric_mp=True, + activation=None, + *, + dropout: float | None, + ): + super().__init__() + self.symmetric_mp = symmetric_mp + + # Dense transformation + self.dense_db = Dense( + emb_size_edge, + emb_size_edge, + activation=activation, + bias=False, + dropout=dropout, + ) + + # Up projections of basis representations, + # bilinear layer and scaling factors + self.mlp_rbf = Dense( + emb_size_rbf, + emb_size_edge, + activation=None, + bias=False, + dropout=dropout, + ) + self.scale_rbf = ScaleFactor() + + self.mlp_cbf = Dense( + emb_size_cbf, + emb_size_quad_in, + activation=None, + bias=False, + dropout=dropout, + ) + self.scale_cbf = ScaleFactor() + + self.mlp_sbf = EfficientInteractionBilinear( + emb_size_quad_in, + emb_size_sbf, + emb_size_quad_out, + dropout=dropout, + ) + self.scale_sbf_sum = ScaleFactor() + # combines scaling for bilinear layer and summation + + # Down and up projections + self.down_projection = Dense( + emb_size_edge, + emb_size_quad_in, + activation=activation, + bias=False, + dropout=dropout, + ) + self.up_projection_ca = Dense( + emb_size_quad_out, + emb_size_edge, + activation=activation, + bias=False, + dropout=dropout, + ) + if self.symmetric_mp: + self.up_projection_ac = Dense( + emb_size_quad_out, + emb_size_edge, + activation=activation, + bias=False, + dropout=dropout, + ) + + self.inv_sqrt_2 = 1 / math.sqrt(2.0) + + def forward( + self, + m, + bases, + idx, + id_swap, + ): + """ + Returns + ------- + m: torch.Tensor, shape=(nEdges, emb_size_edge) + Edge embeddings (c->a). + """ + + x_db = self.dense_db(m) # (nEdges, emb_size_edge) + + # Transform via radial basis + x_db2 = x_db * self.mlp_rbf(bases["rad"]) # (nEdges, emb_size_edge) + x_db = self.scale_rbf(x_db2, ref=x_db) + + # Down project embeddings + x_db = self.down_projection(x_db) # (nEdges, emb_size_quad_in) + + # Transform via circular basis + x_db = x_db[idx["triplet_in"]["in"]] + # (num_triplets_int, emb_size_quad_in) + + x_db2 = x_db * self.mlp_cbf(bases["cir"]) + # (num_triplets_int, emb_size_quad_in) + x_db = self.scale_cbf(x_db2, ref=x_db) + + # Transform via spherical basis + x_db = x_db[idx["trip_in_to_quad"]] + # (num_quadruplets, emb_size_quad_in) + x = self.mlp_sbf(bases["sph"], x_db, idx["out"], idx["out_agg"]) + # (nEdges, emb_size_quad_out) + x = self.scale_sbf_sum(x, ref=x_db) + + # => + # rbf(d_db) + # cbf(d_ba, angle_abd) + # sbf(d_ca, angle_cab, angle_cabd) + + if self.symmetric_mp: + # Upproject embeddings + x_ca = self.up_projection_ca(x) # (nEdges, emb_size_edge) + x_ac = self.up_projection_ac(x) # (nEdges, emb_size_edge) + + # Merge interaction of c->a and a->c + x_ac = x_ac[id_swap] # swap to add to edge a->c and not c->a + x_res = x_ca + x_ac + x_res = x_res * self.inv_sqrt_2 + return x_res + else: + x_res = self.up_projection_ca(x) + return x_res + + +class TripletInteraction(torch.nn.Module): + """ + Triplet-based message passing block. + + Arguments + --------- + emb_size_in: int + Embedding size of the input embeddings. + emb_size_out: int + Embedding size of the output embeddings. + emb_size_trip_in: int + (Down-projected) embedding size of the quadruplet edge embeddings + before the bilinear layer. + emb_size_trip_out: int + (Down-projected) embedding size of the quadruplet edge embeddings + after the bilinear layer. + emb_size_rbf: int + Embedding size of the radial basis transformation. + emb_size_cbf: int + Embedding size of the circular basis transformation (one angle). + symmetric_mp: bool + Whether to use symmetric message passing and + update the edges in both directions. + swap_output: bool + Whether to swap the output embedding directions. + Only relevant if symmetric_mp is False. + activation: str + Name of the activation function to use in the dense layers. + """ + + def __init__( + self, + emb_size_in, + emb_size_out, + emb_size_trip_in, + emb_size_trip_out, + emb_size_rbf, + emb_size_cbf, + symmetric_mp=True, + swap_output=True, + activation=None, + *, + dropout: float | None, + ): + super().__init__() + self.symmetric_mp = symmetric_mp + self.swap_output = swap_output + + # Dense transformation + self.dense_ba = Dense( + emb_size_in, + emb_size_in, + activation=activation, + bias=False, + dropout=dropout, + ) + + # Up projections of basis representations, bilinear layer and scaling factors + self.mlp_rbf = Dense( + emb_size_rbf, + emb_size_in, + activation=None, + bias=False, + dropout=dropout, + ) + self.scale_rbf = ScaleFactor() + + self.mlp_cbf = EfficientInteractionBilinear( + emb_size_trip_in, + emb_size_cbf, + emb_size_trip_out, + dropout=dropout, + ) + self.scale_cbf_sum = ScaleFactor() + # combines scaling for bilinear layer and summation + + # Down and up projections + self.down_projection = Dense( + emb_size_in, + emb_size_trip_in, + activation=activation, + bias=False, + dropout=dropout, + ) + self.up_projection_ca = Dense( + emb_size_trip_out, + emb_size_out, + activation=activation, + bias=False, + dropout=dropout, + ) + if self.symmetric_mp: + self.up_projection_ac = Dense( + emb_size_trip_out, + emb_size_out, + activation=activation, + bias=False, + dropout=dropout, + ) + + self.inv_sqrt_2 = 1 / math.sqrt(2.0) + + def forward( + self, + m, + bases, + idx, + id_swap, + expand_idx=None, + idx_agg2=None, + idx_agg2_inner=None, + agg2_out_size=None, + ): + """ + Returns + ------- + m: torch.Tensor, shape=(nEdges, emb_size_edge) + Edge embeddings. + """ + + # Dense transformation + x_ba = self.dense_ba(m) # (nEdges, emb_size_edge) + + if expand_idx is not None: + x_ba = x_ba[expand_idx] + + # Transform via radial basis + rad_emb = self.mlp_rbf(bases["rad"]) # (nEdges, emb_size_edge) + x_ba2 = x_ba * rad_emb + x_ba = self.scale_rbf(x_ba2, ref=x_ba) + + x_ba = self.down_projection(x_ba) # (nEdges, emb_size_trip_in) + + # Transform via circular spherical basis + x_ba = x_ba[idx["in"]] + + # Efficient bilinear layer + x = self.mlp_cbf( + basis=bases["cir"], + m=x_ba, + idx_agg_outer=idx["out"], + idx_agg_inner=idx["out_agg"], + idx_agg2_outer=idx_agg2, + idx_agg2_inner=idx_agg2_inner, + agg2_out_size=agg2_out_size, + ) + # (num_atoms, emb_size_trip_out) + x = self.scale_cbf_sum(x, ref=x_ba) + + # => + # rbf(d_ba) + # cbf(d_ca, angle_cab) + + if self.symmetric_mp: + # Up project embeddings + x_ca = self.up_projection_ca(x) # (nEdges, emb_size_edge) + x_ac = self.up_projection_ac(x) # (nEdges, emb_size_edge) + + # Merge interaction of c->a and a->c + x_ac = x_ac[id_swap] # swap to add to edge a->c and not c->a + x_res = x_ca + x_ac + x_res = x_res * self.inv_sqrt_2 + return x_res + else: + if self.swap_output: + x = x[id_swap] + x_res = self.up_projection_ca(x) # (nEdges, emb_size_edge) + return x_res + + +class PairInteraction(torch.nn.Module): + """ + Pair-based message passing block. + + Arguments + --------- + emb_size_atom: int + Embedding size of the atoms. + emb_size_pair_in: int + Embedding size of the atom pairs before the bilinear layer. + emb_size_pair_out: int + Embedding size of the atom pairs after the bilinear layer. + emb_size_rbf: int + Embedding size of the radial basis transformation. + activation: str + Name of the activation function to use in the dense layers. + """ + + def __init__( + self, + emb_size_atom, + emb_size_pair_in, + emb_size_pair_out, + emb_size_rbf, + activation=None, + *, + dropout: float | None, + ): + super().__init__() + + # Bilinear layer and scaling factor + self.bilinear = Dense( + emb_size_rbf * emb_size_pair_in, + emb_size_pair_out, + activation=None, + bias=False, + dropout=dropout, + ) + self.scale_rbf_sum = ScaleFactor() + + # Down and up projections + self.down_projection = Dense( + emb_size_atom, + emb_size_pair_in, + activation=activation, + bias=False, + dropout=dropout, + ) + self.up_projection = Dense( + emb_size_pair_out, + emb_size_atom, + activation=activation, + bias=False, + dropout=dropout, + ) + + self.inv_sqrt_2 = 1 / math.sqrt(2.0) + + def forward( + self, + h, + rad_basis, + edge_index, + target_neighbor_idx, + ): + """ + Returns + ------- + h: torch.Tensor, shape=(num_atoms, emb_size_atom) + Atom embeddings. + """ + num_atoms = h.shape[0] + + x_b = self.down_projection(h) # (num_atoms, emb_size_edge) + x_ba = x_b[edge_index[0]] # (num_edges, emb_size_edge) + + Kmax = torch.max(target_neighbor_idx) + 1 + x2 = x_ba.new_zeros(num_atoms, Kmax, x_ba.shape[-1]) + x2[edge_index[1], target_neighbor_idx] = x_ba + # (num_atoms, Kmax, emb_size_edge) + + x_ba2 = rad_basis @ x2 + # (num_atoms, emb_size_interm, emb_size_edge) + h_out = self.bilinear(x_ba2.reshape(num_atoms, -1)) + + h_out = self.scale_rbf_sum(h_out, ref=x_ba) + # (num_atoms, emb_size_edge) + + h_out = self.up_projection(h_out) # (num_atoms, emb_size_atom) + + return h_out diff --git a/src/jmp/models/gemnet/layers/radial_basis_dynamic_cutoff.py b/src/jmp/models/gemnet/layers/radial_basis_dynamic_cutoff.py new file mode 100644 index 0000000..0c9491f --- /dev/null +++ b/src/jmp/models/gemnet/layers/radial_basis_dynamic_cutoff.py @@ -0,0 +1,287 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import math +from logging import getLogger +from typing import cast + +import numpy as np +import torch +import torch.nn as nn +from scipy.special import binom +from torch_geometric.data import Batch + +from ....modules.scaling import ScaleFactor + +log = getLogger(__name__) + + +class PolynomialEnvelope(nn.Module): + """ + Polynomial envelope function that ensures a smooth cutoff. + + Arguments + --------- + exponent: int + Exponent of the envelope function. + """ + + def __init__(self, exponent): + super().__init__() + assert exponent > 0 + self.p = exponent + self.a = -(self.p + 1) * (self.p + 2) / 2 + self.b = self.p * (self.p + 2) + self.c = -self.p * (self.p + 1) / 2 + + def forward(self, d_scaled): + env_val = ( + 1 + + self.a * d_scaled**self.p + + self.b * d_scaled ** (self.p + 1) + + self.c * d_scaled ** (self.p + 2) + ) + return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) + + +class ExponentialEnvelope(nn.Module): + """ + Exponential envelope function that ensures a smooth cutoff, + as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. + SpookyNet: Learning Force Fields with Electronic Degrees of Freedom + and Nonlocal Effects + """ + + def __init__(self): + super().__init__() + + def forward(self, d_scaled): + env_val = torch.exp(-(d_scaled**2) / ((1 - d_scaled) * (1 + d_scaled))) + return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) + + +class GaussianBasis(nn.Module): + def __init__( + self, + start=0.0, + stop=5.0, + num_gaussians=50, + trainable=False, + trainable_stds=False, + ): + super().__init__() + + self.trainable = trainable + self.trainable_stds = trainable_stds + + offset = torch.linspace(start, stop, num_gaussians) + if self.trainable: + _ = nn.init.uniform_(offset, start, stop) + self.offset = nn.Parameter(offset, requires_grad=True) + else: + self.register_buffer("offset", offset) + + if self.trainable_stds: + self.temps = nn.Parameter(torch.empty(num_gaussians), requires_grad=True) + temp_mean = -0.5 / ((stop - start) / (num_gaussians - 1)) ** 2 + _ = nn.init.normal_(self.temps, mean=temp_mean, std=abs(temp_mean) * 2.0) + else: + temp = 0.5 / ((stop - start) / (num_gaussians - 1)) ** 2 + self.register_buffer("temps", torch.full_like(offset, temp)) + + def forward(self, dist): + dist = dist[:, None] - self.offset[None, :] + return torch.exp(dist.square() * -(self.temps.abs())) + + +class SphericalBesselBasis(nn.Module): + """ + First-order spherical Bessel basis + + Arguments + --------- + num_radial: int + Number of basis functions. Controls the maximum frequency. + cutoff: float + Cutoff distance in Angstrom. + """ + + def __init__( + self, + num_radial: int, + cutoff: float, + ): + super().__init__() + self.norm_const = math.sqrt(2 / (cutoff**3)) + # cutoff ** 3 to counteract dividing by d_scaled = d / cutoff + + # Initialize frequencies at canonical positions + self.frequencies = nn.Parameter( + data=torch.tensor(np.pi * np.arange(1, num_radial + 1, dtype=np.float32)), + requires_grad=True, + ) + + def forward(self, d_scaled): + return ( + self.norm_const + / d_scaled[:, None] + * torch.sin(self.frequencies * d_scaled[:, None]) + ) # (num_edges, num_radial) + + +class BernsteinBasis(nn.Module): + """ + Bernstein polynomial basis, + as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. + SpookyNet: Learning Force Fields with Electronic Degrees of Freedom + and Nonlocal Effects + + Arguments + --------- + num_radial: int + Number of basis functions. Controls the maximum frequency. + pregamma_initial: float + Initial value of exponential coefficient gamma. + Default: gamma = 0.5 * a_0**-1 = 0.94486, + inverse softplus -> pregamma = log e**gamma - 1 = 0.45264 + """ + + def __init__( + self, + num_radial: int, + pregamma_initial: float = 0.45264, + ): + super().__init__() + prefactor = binom(num_radial - 1, np.arange(num_radial)) + self.register_buffer( + "prefactor", + torch.tensor(prefactor, dtype=torch.float), + persistent=False, + ) + + self.pregamma = nn.Parameter( + data=torch.tensor(pregamma_initial, dtype=torch.float), + requires_grad=True, + ) + self.softplus = nn.Softplus() + + exp1 = torch.arange(num_radial) + self.register_buffer("exp1", exp1[None, :], persistent=False) + exp2 = num_radial - 1 - exp1 + self.register_buffer("exp2", exp2[None, :], persistent=False) + + def forward(self, d_scaled): + gamma = self.softplus(self.pregamma) # constrain to positive + exp_d = torch.exp(-gamma * d_scaled)[:, None] + return self.prefactor * (exp_d**self.exp1) * ((1 - exp_d) ** self.exp2) + + +class RadialBasis(nn.Module): + """ + + Arguments + --------- + num_radial: int + Number of basis functions. Controls the maximum frequency. + cutoff: float + Cutoff distance in Angstrom. + rbf: dict = {"name": "gaussian"} + Basis function and its hyperparameters. + envelope: dict = {"name": "polynomial", "exponent": 5} + Envelope function and its hyperparameters. + scale_basis: bool + Whether to scale the basis values for better numerical stability. + """ + + def __init__( + self, + num_radial: int, + graph_type: str, + rbf: dict = {"name": "gaussian"}, + envelope: dict = {"name": "polynomial", "exponent": 5}, + scale_basis: bool = False, + *, + absolute_cutoff: float | None, + ): + super().__init__() + + self.absolute_cutoff = absolute_cutoff + if self.absolute_cutoff: + log.info( + f"[{self.__class__.__qualname__}] Using absolute cutoff of {self.absolute_cutoff} Angstroms." + ) + else: + log.info(f"[{self.__class__.__qualname__}] Using relative cutoff.") + + # self.inv_cutoff = 1 / cutoff + self.graph_type = graph_type + + self.scale_basis = scale_basis + if self.scale_basis: + self.scale_rbf = ScaleFactor() + + env_name = envelope["name"].lower() + env_hparams = envelope.copy() + del env_hparams["name"] + + if env_name == "polynomial": + self.envelope = PolynomialEnvelope(**env_hparams) + elif env_name == "exponential": + self.envelope = ExponentialEnvelope(**env_hparams) + else: + raise ValueError(f"Unknown envelope function '{env_name}'.") + + rbf_name = rbf["name"].lower() + rbf_hparams = rbf.copy() + del rbf_hparams["name"] + + # RBFs get distances scaled to be in [0, 1] + if rbf_name == "gaussian": + self.rbf = GaussianBasis( + start=0, stop=1, num_gaussians=num_radial, **rbf_hparams + ) + elif rbf_name == "spherical_bessel": + assert ( + absolute_cutoff is not None + ), "Spherical Bessel basis requires absolute cutoff." + self.rbf = SphericalBesselBasis( + num_radial=num_radial, cutoff=absolute_cutoff, **rbf_hparams + ) + elif rbf_name == "bernstein": + self.rbf = BernsteinBasis(num_radial=num_radial, **rbf_hparams) + else: + raise ValueError(f"Unknown radial basis function '{rbf_name}'.") + + @staticmethod + def _get_tensor(data: Batch, key: str): + value = getattr(data, key, None) + if value is None: + raise ValueError(f"Batch does not contain key '{key}'.") + if not torch.is_tensor(value): + raise ValueError(f"Key '{key}' must be a tensor.") + + value = cast(torch.Tensor, value) + return value + + def forward(self, d, *, data: Batch): + if not (cutoff := self.absolute_cutoff): + edge_index = self._get_tensor(data, f"{self.graph_type}_edge_index") + cutoff = self._get_tensor(data, f"{self.graph_type}_cutoff") # (b,) + cutoff = cutoff[data.batch] # (n_nodes,) + cutoff = cutoff[edge_index[0]] # (n_edges,) + + d_scaled = d / cutoff + + env = self.envelope(d_scaled) + res = env[:, None] * self.rbf(d_scaled) + + if self.scale_basis: + res = self.scale_rbf(res) + + return res + # (num_edges, num_radial) or (num_edges, num_orders * num_radial) diff --git a/src/jmp/models/gemnet/layers/spherical_basis_dynamic_cutoff.py b/src/jmp/models/gemnet/layers/spherical_basis_dynamic_cutoff.py new file mode 100644 index 0000000..28cc60d --- /dev/null +++ b/src/jmp/models/gemnet/layers/spherical_basis_dynamic_cutoff.py @@ -0,0 +1,139 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +import torch.nn as nn + +from ....modules.scaling import ScaleFactor +from .basis_utils import get_sph_harm_basis +from .radial_basis_dynamic_cutoff import GaussianBasis, RadialBasis + + +class CircularBasisLayer(nn.Module): + """ + 2D Fourier Bessel Basis + + Arguments + --------- + num_spherical: int + Number of basis functions. Controls the maximum frequency. + radial_basis: RadialBasis + Radial basis function. + cbf: dict + Name and hyperparameters of the circular basis function. + scale_basis: bool + Whether to scale the basis values for better numerical stability. + """ + + def __init__( + self, + num_spherical: int, + radial_basis: RadialBasis, + cbf: dict, + scale_basis: bool = False, + ): + super().__init__() + + self.radial_basis = radial_basis + + self.scale_basis = scale_basis + if self.scale_basis: + self.scale_cbf = ScaleFactor() + + cbf_name = cbf["name"].lower() + cbf_hparams = cbf.copy() + del cbf_hparams["name"] + + if cbf_name == "gaussian": + self.cosφ_basis = GaussianBasis( + start=-1, stop=1, num_gaussians=num_spherical, **cbf_hparams + ) + elif cbf_name == "spherical_harmonics": + self.cosφ_basis = get_sph_harm_basis(num_spherical, zero_m_only=True) + else: + raise ValueError(f"Unknown cosine basis function '{cbf_name}'.") + + def forward(self, D_ca, cosφ_cab, *, data): + rad_basis = self.radial_basis(D_ca, data=data) # (num_edges, num_radial) + cir_basis = self.cosφ_basis(cosφ_cab) # (num_triplets, num_spherical) + + if self.scale_basis: + cir_basis = self.scale_cbf(cir_basis) + + return rad_basis, cir_basis + # (num_edges, num_radial), (num_triplets, num_spherical) + + +class SphericalBasisLayer(nn.Module): + """ + 3D Fourier Bessel Basis + + Arguments + --------- + num_spherical: int + Number of basis functions. Controls the maximum frequency. + radial_basis: RadialBasis + Radial basis functions. + sbf: dict + Name and hyperparameters of the spherical basis function. + scale_basis: bool + Whether to scale the basis values for better numerical stability. + """ + + def __init__( + self, + num_spherical: int, + radial_basis: RadialBasis, + sbf: dict, + scale_basis: bool = False, + ): + super().__init__() + + self.num_spherical = num_spherical + self.radial_basis = radial_basis + + self.scale_basis = scale_basis + if self.scale_basis: + self.scale_sbf = ScaleFactor() + + sbf_name = sbf["name"].lower() + sbf_hparams = sbf.copy() + del sbf_hparams["name"] + + if sbf_name == "spherical_harmonics": + self.spherical_basis = get_sph_harm_basis(num_spherical, zero_m_only=False) + + elif sbf_name == "legendre_outer": + circular_basis = get_sph_harm_basis(num_spherical, zero_m_only=True) + self.spherical_basis = lambda cosφ, ϑ: ( + circular_basis(cosφ)[:, :, None] + * circular_basis(torch.cos(ϑ))[:, None, :] + ).reshape(cosφ.shape[0], -1) + + elif sbf_name == "gaussian_outer": + self.circular_basis = GaussianBasis( + start=-1, stop=1, num_gaussians=num_spherical, **sbf_hparams + ) + self.spherical_basis = lambda cosφ, ϑ: ( + self.circular_basis(cosφ)[:, :, None] + * self.circular_basis(torch.cos(ϑ))[:, None, :] + ).reshape(cosφ.shape[0], -1) + + else: + raise ValueError(f"Unknown spherical basis function '{sbf_name}'.") + + def forward(self, D_ca, cosφ_cab, θ_cabd, *, data): + rad_basis = self.radial_basis(D_ca, data=data) + sph_basis = self.spherical_basis(cosφ_cab, θ_cabd) + # (num_quadruplets, num_spherical**2) + + if self.scale_basis: + sph_basis = self.scale_sbf(sph_basis) + + return rad_basis, sph_basis + # (num_edges, num_radial), (num_quadruplets, num_spherical**2) diff --git a/src/jmp/models/gemnet/scale_files/large.pt b/src/jmp/models/gemnet/scale_files/large.pt new file mode 100644 index 0000000..3917aa6 Binary files /dev/null and b/src/jmp/models/gemnet/scale_files/large.pt differ diff --git a/src/jmp/models/gemnet/scale_files/small.pt b/src/jmp/models/gemnet/scale_files/small.pt new file mode 100644 index 0000000..1ceaa39 Binary files /dev/null and b/src/jmp/models/gemnet/scale_files/small.pt differ diff --git a/src/jmp/models/gemnet/utils.py b/src/jmp/models/gemnet/utils.py new file mode 100644 index 0000000..c2ce70c --- /dev/null +++ b/src/jmp/models/gemnet/utils.py @@ -0,0 +1,563 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch +from torch_scatter import segment_coo, segment_csr +from torch_sparse import SparseTensor + + +def get_max_neighbors_mask_tensor( + natoms: torch.Tensor, + index: torch.Tensor, + atom_distance: torch.Tensor, + max_num_neighbors_threshold: torch.Tensor, +): + """ + Give a mask that filters out edges so that each atom has at most + `max_num_neighbors_threshold` neighbors. + Assumes that `index` is sorted. + """ + device = natoms.device + num_atoms = int(natoms.sum()) + + # Get number of neighbors + # segment_coo assumes sorted index + ones = index.new_ones(1).expand_as(index) + num_neighbors = segment_coo(ones, index, dim_size=num_atoms) + max_num_neighbors = int(num_neighbors.max()) + num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) + + # Get number of (thresholded) neighbors per image + image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) + image_indptr[1:] = torch.cumsum(natoms, dim=0) + num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + + # If max_num_neighbors is below the threshold, return early + if (max_num_neighbors_threshold >= max_num_neighbors).all(): + mask_num_neighbors = torch.ones_like(num_neighbors, dtype=torch.bool) + return mask_num_neighbors, num_neighbors_image + + # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. + # Fill with infinity so we can easily remove unused distances later. + distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device) + + # Create an index map to map distances from atom_distance to distance_sort + # index_sort_map assumes index to be sorted + index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors + index_neighbor_offset_expand = torch.repeat_interleave( + index_neighbor_offset, num_neighbors + ) + index_sort_map = ( + index * max_num_neighbors + + torch.arange(len(index), device=device) + - index_neighbor_offset_expand + ) + distance_sort.index_copy_(0, index_sort_map, atom_distance) + distance_sort = distance_sort.view(num_atoms, max_num_neighbors) + + # Sort neighboring atoms based on distance + distance_sort, index_sort = torch.sort(distance_sort, dim=1) + # Select the max_num_neighbors_threshold neighbors that are closest + distance_sort = distance_sort[:, :max_num_neighbors_threshold] + index_sort = index_sort[:, :max_num_neighbors_threshold] + + # Offset index_sort so that it indexes into index + index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( + -1, max_num_neighbors_threshold + ) + # Remove "unused pairs" with infinite distances + mask_finite = torch.isfinite(distance_sort) + index_sort = torch.masked_select(index_sort, mask_finite) + + # At this point index_sort contains the index into index of the + # closest max_num_neighbors_threshold neighbors per atom + # Create a mask to remove all pairs not in index_sort + mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool) + mask_num_neighbors.index_fill_(0, index_sort, True) + + return mask_num_neighbors, num_neighbors_image + + +def get_max_neighbors_mask(natoms, index, atom_distance, max_num_neighbors_threshold): + """ + Give a mask that filters out edges so that each atom has at most + `max_num_neighbors_threshold` neighbors. + Assumes that `index` is sorted. + """ + device = natoms.device + num_atoms = natoms.sum() + + # Get number of neighbors + # segment_coo assumes sorted index + ones = index.new_ones(1).expand_as(index) + num_neighbors = segment_coo(ones, index, dim_size=num_atoms) + max_num_neighbors = num_neighbors.max() + num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) + + # Get number of (thresholded) neighbors per image + image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) + image_indptr[1:] = torch.cumsum(natoms, dim=0) + num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + + # If max_num_neighbors is below the threshold, return early + if ( + max_num_neighbors <= max_num_neighbors_threshold + or max_num_neighbors_threshold <= 0 + ): + mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as( + index + ) + return mask_num_neighbors, num_neighbors_image + + # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. + # Fill with infinity so we can easily remove unused distances later. + distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device) + + # Create an index map to map distances from atom_distance to distance_sort + # index_sort_map assumes index to be sorted + index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors + index_neighbor_offset_expand = torch.repeat_interleave( + index_neighbor_offset, num_neighbors + ) + index_sort_map = ( + index * max_num_neighbors + + torch.arange(len(index), device=device) + - index_neighbor_offset_expand + ) + distance_sort.index_copy_(0, index_sort_map, atom_distance) + distance_sort = distance_sort.view(num_atoms, max_num_neighbors) + + # Sort neighboring atoms based on distance + distance_sort, index_sort = torch.sort(distance_sort, dim=1) + # Select the max_num_neighbors_threshold neighbors that are closest + distance_sort = distance_sort[:, :max_num_neighbors_threshold] + index_sort = index_sort[:, :max_num_neighbors_threshold] + + # Offset index_sort so that it indexes into index + index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( + -1, max_num_neighbors_threshold + ) + # Remove "unused pairs" with infinite distances + mask_finite = torch.isfinite(distance_sort) + index_sort = torch.masked_select(index_sort, mask_finite) + + # At this point index_sort contains the index into index of the + # closest max_num_neighbors_threshold neighbors per atom + # Create a mask to remove all pairs not in index_sort + mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool) + mask_num_neighbors.index_fill_(0, index_sort, True) + + return mask_num_neighbors, num_neighbors_image + + +def ragged_range(sizes): + """Multiple concatenated ranges. + + Examples + -------- + sizes = [1 4 2 3] + Return: [0 0 1 2 3 0 1 0 1 2] + """ + assert sizes.dim() == 1 + if sizes.sum() == 0: + return sizes.new_empty(0) + + # Remove 0 sizes + sizes_nonzero = sizes > 0 + if not torch.all(sizes_nonzero): + sizes = torch.masked_select(sizes, sizes_nonzero) + + # Initialize indexing array with ones as we need to setup incremental indexing + # within each group when cumulatively summed at the final stage. + id_steps = torch.ones(sizes.sum(), dtype=torch.long, device=sizes.device) + id_steps[0] = 0 + insert_index = sizes[:-1].cumsum(0) + insert_val = (1 - sizes)[:-1] + + # Assign index-offsetting values + id_steps[insert_index] = insert_val + + # Finally index into input array for the group repeated o/p + res = id_steps.cumsum(0) + return res + + +def repeat_blocks( + sizes, + repeats, + continuous_indexing=True, + start_idx=0, + block_inc=0, + repeat_inc=0, +): + """Repeat blocks of indices. + Adapted from https://stackoverflow.com/questions/51154989/numpy-vectorized-function-to-repeat-blocks-of-consecutive-elements + + continuous_indexing: Whether to keep increasing the index after each block + start_idx: Starting index + block_inc: Number to increment by after each block, + either global or per block. Shape: len(sizes) - 1 + repeat_inc: Number to increment by after each repetition, + either global or per block + + Examples + -------- + sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = False + Return: [0 0 0 0 1 2 0 1 2 0 1 0 1 0 1] + sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = True + Return: [0 0 0 1 2 3 1 2 3 4 5 4 5 4 5] + sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = True ; + repeat_inc = 4 + Return: [0 4 8 1 2 3 5 6 7 4 5 8 9 12 13] + sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = True ; + start_idx = 5 + Return: [5 5 5 6 7 8 6 7 8 9 10 9 10 9 10] + sizes = [1,3,2] ; repeats = [3,2,3] ; continuous_indexing = True ; + block_inc = 1 + Return: [0 0 0 2 3 4 2 3 4 6 7 6 7 6 7] + sizes = [0,3,2] ; repeats = [3,2,3] ; continuous_indexing = True + Return: [0 1 2 0 1 2 3 4 3 4 3 4] + sizes = [2,3,2] ; repeats = [2,0,2] ; continuous_indexing = True + Return: [0 1 0 1 5 6 5 6] + """ + assert sizes.dim() == 1 + assert all(sizes >= 0) + + # Remove 0 sizes + sizes_nonzero = sizes > 0 + if not torch.all(sizes_nonzero): + assert block_inc == 0 # Implementing this is not worth the effort + sizes = torch.masked_select(sizes, sizes_nonzero) + if isinstance(repeats, torch.Tensor): + repeats = torch.masked_select(repeats, sizes_nonzero) + if isinstance(repeat_inc, torch.Tensor): + repeat_inc = torch.masked_select(repeat_inc, sizes_nonzero) + + if isinstance(repeats, torch.Tensor): + assert all(repeats >= 0) + insert_dummy = repeats[0] == 0 + if insert_dummy: + one = sizes.new_ones(1) + zero = sizes.new_zeros(1) + sizes = torch.cat((one, sizes)) + repeats = torch.cat((one, repeats)) + if isinstance(block_inc, torch.Tensor): + block_inc = torch.cat((zero, block_inc)) + if isinstance(repeat_inc, torch.Tensor): + repeat_inc = torch.cat((zero, repeat_inc)) + else: + assert repeats >= 0 + insert_dummy = False + + # Get repeats for each group using group lengths/sizes + r1 = torch.repeat_interleave(torch.arange(len(sizes), device=sizes.device), repeats) + + # Get total size of output array, as needed to initialize output indexing array + N = (sizes * repeats).sum() + + # Initialize indexing array with ones as we need to setup incremental indexing + # within each group when cumulatively summed at the final stage. + # Two steps here: + # 1. Within each group, we have multiple sequences, so setup the offsetting + # at each sequence lengths by the seq. lengths preceding those. + id_ar = torch.ones(N, dtype=torch.long, device=sizes.device) + id_ar[0] = 0 + insert_index = sizes[r1[:-1]].cumsum(0) + insert_val = (1 - sizes)[r1[:-1]] + + if isinstance(repeats, torch.Tensor) and torch.any(repeats == 0): + diffs = r1[1:] - r1[:-1] + indptr = torch.cat((sizes.new_zeros(1), diffs.cumsum(0))) + if continuous_indexing: + # If a group was skipped (repeats=0) we need to add its size + insert_val += segment_csr(sizes[: r1[-1]], indptr, reduce="sum") + + # Add block increments + if isinstance(block_inc, torch.Tensor): + insert_val += segment_csr(block_inc[: r1[-1]], indptr, reduce="sum") + else: + insert_val += block_inc * (indptr[1:] - indptr[:-1]) + if insert_dummy: + insert_val[0] -= block_inc + else: + idx = r1[1:] != r1[:-1] + if continuous_indexing: + # 2. For each group, make sure the indexing starts from the next group's + # first element. So, simply assign 1s there. + insert_val[idx] = 1 + + # Add block increments + insert_val[idx] += block_inc + + # Add repeat_inc within each group + if isinstance(repeat_inc, torch.Tensor): + insert_val += repeat_inc[r1[:-1]] + if isinstance(repeats, torch.Tensor): + repeat_inc_inner = repeat_inc[repeats > 0][:-1] + else: + repeat_inc_inner = repeat_inc[:-1] + else: + insert_val += repeat_inc + repeat_inc_inner = repeat_inc + + # Subtract the increments between groups + if isinstance(repeats, torch.Tensor): + repeats_inner = repeats[repeats > 0][:-1] + else: + repeats_inner = repeats + insert_val[r1[1:] != r1[:-1]] -= repeat_inc_inner * repeats_inner + + # Assign index-offsetting values + id_ar[insert_index] = insert_val + + if insert_dummy: + id_ar = id_ar[1:] + if continuous_indexing: + id_ar[0] -= 1 + + # Set start index now, in case of insertion due to leading repeats=0 + id_ar[0] += start_idx + + # Finally index into input array for the group repeated o/p + res = id_ar.cumsum(0) + return res + + +def masked_select_sparsetensor_flat(src, mask): + row, col, value = src.coo() + row = row[mask] + col = col[mask] + value = value[mask] + return SparseTensor(row=row, col=col, value=value, sparse_sizes=src.sparse_sizes()) + + +def calculate_interatomic_vectors(R, id_s, id_t, offsets_st): + """ + Calculate the vectors connecting the given atom pairs, + considering offsets from periodic boundary conditions (PBC). + + Arguments + --------- + R: Tensor, shape = (nAtoms, 3) + Atom positions. + id_s: Tensor, shape = (nEdges,) + Indices of the source atom of the edges. + id_t: Tensor, shape = (nEdges,) + Indices of the target atom of the edges. + offsets_st: Tensor, shape = (nEdges,) + PBC offsets of the edges. + Subtract this from the correct direction. + + Returns + ------- + (D_st, V_st): tuple + D_st: Tensor, shape = (nEdges,) + Distance from atom t to s. + V_st: Tensor, shape = (nEdges,) + Unit direction from atom t to s. + """ + Rs = R[id_s] + Rt = R[id_t] + # ReLU prevents negative numbers in sqrt + if offsets_st is None: + V_st = Rt - Rs # s -> t + else: + V_st = Rt - Rs + offsets_st # s -> t + D_st = torch.sqrt(torch.sum(V_st**2, dim=1)) + V_st = V_st / D_st[..., None] + return D_st, V_st + + +def inner_product_clamped(x, y): + """ + Calculate the inner product between the given normalized vectors, + giving a result between -1 and 1. + """ + return torch.sum(x * y, dim=-1).clamp(min=-1, max=1) + + +def get_angle(R_ac, R_ab): + """Calculate angles between atoms c -> a <- b. + + Arguments + --------- + R_ac: Tensor, shape = (N, 3) + Vector from atom a to c. + R_ab: Tensor, shape = (N, 3) + Vector from atom a to b. + + Returns + ------- + angle_cab: Tensor, shape = (N,) + Angle between atoms c <- a -> b. + """ + # cos(alpha) = (u * v) / (|u|*|v|) + x = torch.sum(R_ac * R_ab, dim=-1) # shape = (N,) + # sin(alpha) = |u x v| / (|u|*|v|) + y = torch.cross(R_ac, R_ab, dim=-1).norm(dim=-1) # shape = (N,) + y = y.clamp(min=1e-9) # Avoid NaN gradient for y = (0,0,0) + + angle = torch.atan2(y, x) + return angle + + +def vector_rejection(R_ab, P_n): + """ + Project the vector R_ab onto a plane with normal vector P_n. + + Arguments + --------- + R_ab: Tensor, shape = (N, 3) + Vector from atom a to b. + P_n: Tensor, shape = (N, 3) + Normal vector of a plane onto which to project R_ab. + + Returns + ------- + R_ab_proj: Tensor, shape = (N, 3) + Projected vector (orthogonal to P_n). + """ + a_x_b = torch.sum(R_ab * P_n, dim=-1) + b_x_b = torch.sum(P_n * P_n, dim=-1) + return R_ab - (a_x_b / b_x_b)[:, None] * P_n + + +def get_projected_angle(R_ab, P_n, eps=1e-4): + """ + Project the vector R_ab onto a plane with normal vector P_n, + then calculate the angle w.r.t. the (x [cross] P_n), + or (y [cross] P_n) if the former would be ill-defined/numerically unstable. + + Arguments + --------- + R_ab: Tensor, shape = (N, 3) + Vector from atom a to b. + P_n: Tensor, shape = (N, 3) + Normal vector of a plane onto which to project R_ab. + eps: float + Norm of projection below which to use the y-axis instead of x. + + Returns + ------- + angle_ab: Tensor, shape = (N) + Angle on plane w.r.t. x- or y-axis. + """ + R_ab_proj = torch.cross(R_ab, P_n, dim=-1) + + # Obtain axis defining the angle=0 + x = P_n.new_tensor([[1, 0, 0]]).expand_as(P_n) + zero_angle = torch.cross(x, P_n, dim=-1) + + use_y = torch.norm(zero_angle, dim=-1) < eps + P_n_y = P_n[use_y] + y = P_n_y.new_tensor([[0, 1, 0]]).expand_as(P_n_y) + y_cross = torch.cross(y, P_n_y, dim=-1) + zero_angle[use_y] = y_cross + + angle = get_angle(zero_angle, R_ab_proj) + + # Flip sign of angle if necessary to obtain clock-wise angles + cross = torch.cross(zero_angle, R_ab_proj, dim=-1) + flip_sign = torch.sum(cross * P_n, dim=-1) < 0 + angle[flip_sign] = -angle[flip_sign] + + return angle + + +def mask_neighbors(neighbors, edge_mask): + neighbors_old_indptr = torch.cat([neighbors.new_zeros(1), neighbors]) + neighbors_old_indptr = torch.cumsum(neighbors_old_indptr, dim=0) + neighbors = segment_csr(edge_mask.long(), neighbors_old_indptr) + return neighbors + + +def get_neighbor_order(num_atoms, index, atom_distance): + """ + Give a mask that filters out edges so that each atom has at most + `max_num_neighbors_threshold` neighbors. + """ + device = index.device + + # Get sorted index and inverse sorting + # Necessary for index_sort_map + index_sorted, index_order = torch.sort(index) + index_order_inverse = torch.argsort(index_order) + + # Get number of neighbors + ones = index_sorted.new_ones(1).expand_as(index_sorted) + num_neighbors = segment_coo(ones, index_sorted, dim_size=num_atoms) + max_num_neighbors = num_neighbors.max() + + # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. + # Fill with infinity so we can easily remove unused distances later. + distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device) + + # Create an index map to map distances from atom_distance to distance_sort + index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors + index_neighbor_offset_expand = torch.repeat_interleave( + index_neighbor_offset, num_neighbors + ) + index_sort_map = ( + index_sorted * max_num_neighbors + + torch.arange(len(index_sorted), device=device) + - index_neighbor_offset_expand + ) + distance_sort.index_copy_(0, index_sort_map, atom_distance) + distance_sort = distance_sort.view(num_atoms, max_num_neighbors) + + # Sort neighboring atoms based on distance + distance_sort, index_sort = torch.sort(distance_sort, dim=1) + + # Offset index_sort so that it indexes into index_sorted + index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( + -1, max_num_neighbors + ) + # Remove "unused pairs" with infinite distances + mask_finite = torch.isfinite(distance_sort) + index_sort = torch.masked_select(index_sort, mask_finite) + + # Create indices specifying the order in index_sort + order_peratom = torch.arange(max_num_neighbors, device=device)[None, :].expand_as( + mask_finite + ) + order_peratom = torch.masked_select(order_peratom, mask_finite) + + # Re-index to obtain order value of each neighbor in index_sorted + order = torch.zeros(len(index), device=device, dtype=torch.long) + order[index_sort] = order_peratom + + return order[index_order_inverse] + + +def get_inner_idx(idx, dim_size): + """ + Assign an inner index to each element (neighbor) with the same index. + For example, with idx=[0 0 0 1 1 1 1 2 2] this returns [0 1 2 0 1 2 3 0 1]. + These indices allow reshape neighbor indices into a dense matrix. + idx has to be sorted for this to work. + """ + ones = idx.new_ones(1).expand_as(idx) + # ones_cpu = ones.cpu() + # idx_cpu = idx.cpu() + # nn_cpu = segment_coo(ones.cpu(), idx.cpu(), dim_size=dim_size) + # print(ones_cpu, idx_cpu, nn_cpu, nn_cpu.shape, nn_cpu.sum()) + num_neighbors = segment_coo(ones, idx, dim_size=dim_size) + inner_idx = ragged_range(num_neighbors) + return inner_idx + + +def get_edge_id(edge_idx, cell_offsets, num_atoms): + cell_basis = cell_offsets.max() - cell_offsets.min() + 1 + cell_id = ( + (cell_offsets * cell_offsets.new_tensor([[1, cell_basis, cell_basis**2]])) + .sum(-1) + .long() + ) + edge_id = edge_idx[0] + edge_idx[1] * num_atoms + cell_id * num_atoms**2 + return edge_id diff --git a/src/jmp/modules/dataset/common.py b/src/jmp/modules/dataset/common.py new file mode 100644 index 0000000..d135315 --- /dev/null +++ b/src/jmp/modules/dataset/common.py @@ -0,0 +1,81 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import assert_never + +import numpy as np +import torch +from jmp.lightning import TypedConfig + +from . import dataset_transform as DT +from .dataset_typing import TDataset + + +class DatasetSampleNConfig(TypedConfig): + sample_n: int + """Number of samples to take from the dataset""" + + seed: int + """Seed for the random number generator used to sample the dataset""" + + +class DatasetAtomRefConfig(TypedConfig): + refs: dict[str, dict[int, float] | list[float] | np.ndarray | torch.Tensor] + """ + Reference values for each property. + + For each property, the references can be provided as: + - A dictionary with the atom index as the key and the reference value as the value + - A list with the reference values, `(max_atomic_number,)` + - A numpy array with the reference values, `(max_atomic_number,)` + - A torch tensor with the reference values, `(max_atomic_number,)` + """ + + +def _atomref_to_tensor( + value: dict[int, float] | list[float] | np.ndarray | torch.Tensor, +) -> torch.Tensor: + match value: + case dict(): + max_atomic_number = max(value.keys()) + tensor = torch.zeros(max_atomic_number + 1) + for key, val in value.items(): + tensor[key] = val + return tensor + case list(): + return torch.tensor(value) + case np.ndarray(): + return torch.tensor(value) + case torch.Tensor(): + return value + case _: + assert_never(value) + + +class CommonDatasetConfig(TypedConfig): + sample_n: DatasetSampleNConfig | None = None + """Sample n samples from the dataset""" + + atom_ref: DatasetAtomRefConfig | None = None + """Configuration for referencing methods for atoms""" + + +def wrap_common_dataset(dataset: TDataset, config: CommonDatasetConfig) -> TDataset: + if (sample_n := config.sample_n) is not None: + dataset = DT.sample_n_transform( + dataset, + n=sample_n.sample_n, + seed=sample_n.seed, + ) + + if (atom_ref := config.atom_ref) is not None: + # Covnert the refs to a dict of tensors + refs = {key: _atomref_to_tensor(value) for key, value in atom_ref.refs.items()} + dataset = DT.atomref_transform(dataset, refs) + + return dataset diff --git a/src/jmp/modules/dataset/concat_dataset.py b/src/jmp/modules/dataset/concat_dataset.py new file mode 100644 index 0000000..4086eed --- /dev/null +++ b/src/jmp/modules/dataset/concat_dataset.py @@ -0,0 +1,301 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import cache, partial +from logging import getLogger +from typing import Generic, Literal, cast + +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange +from jmp.lightning import TypedConfig +from torch.utils.data import ConcatDataset +from torch_geometric.data.data import BaseData +from typing_extensions import override + +from ..metadata import post_create_dataset +from . import dataset_transform as DT +from .dataset_transform import expand_dataset +from .dataset_typing import DatasetProtocol, TDataset + +log = getLogger(__name__) + + +def _update_graph_value(data: BaseData, key: str, onehot: torch.Tensor): + assert (value := getattr(data, key, None)) is not None, f"{key} must be defined." + if not torch.is_tensor(value): + value = torch.tensor(value, dtype=torch.float) + + value = cast(torch.Tensor, value) + value = rearrange(value.view(-1), "1 -> 1 1") * onehot + setattr(data, key, value) + + +def _update_node_value(data: BaseData, key: str, onehot: torch.Tensor): + assert (value := getattr(data, key, None)) is not None, f"{key} must be defined." + assert torch.is_tensor(value), f"{key} must be a tensor." + value = cast(torch.Tensor, value) + value = rearrange(value, "n ... -> n ... 1") * onehot + setattr(data, key, value) + + +def _update_task_idx_transform( + data: BaseData, + *, + task_idx: int, + num_tasks: int, + taskify_keys_graph: list[str] = ["y"], + taskify_keys_node: list[str] = ["force"], + use_onehot: bool = True, +): + data.task_idx = torch.tensor(task_idx, dtype=torch.long) + # set one-hot vector + onehot: torch.Tensor = F.one_hot( + data.task_idx, num_classes=num_tasks + ).bool() # (t,) + taskify_onehot = onehot if use_onehot else torch.ones_like(onehot, dtype=torch.bool) + + # set task boolean mask + data.task_mask = rearrange(onehot, "t -> 1 t") + + # update graph-level attrs to be a one-hot vector * attr + for key in taskify_keys_graph: + _update_graph_value(data, key, taskify_onehot) + if f"{key}_norm_mean" in data: + _update_graph_value(data, f"{key}_norm_mean", taskify_onehot) + if f"{key}_norm_std" in data: + _update_graph_value(data, f"{key}_norm_std", taskify_onehot) + + # update node-level attrs to be a one-hot vector * attr + for key in taskify_keys_node: + _update_node_value(data, key, taskify_onehot) + if f"{key}_norm_mean" in data: + _update_node_value(data, f"{key}_norm_mean", taskify_onehot) + if f"{key}_norm_std" in data: + _update_node_value(data, f"{key}_norm_std", taskify_onehot) + + return data + + +class _MTConcatDataset(ConcatDataset[BaseData], Generic[TDataset]): + """ + Small wrapper around `ConcatDataset` which handles the concatenation of + `atoms_metadata` properly. `atoms_metadata` stores the number of atoms in + each molecule in the dataset and is used for balancing the batches + during runtime without having to load the entire molecule into memory. + """ + + datasets: list[TDataset] + + @override + def __init__( + self, + datasets: list[TDataset], + *, + taskify_keys_graph: list[str], + taskify_keys_node: list[str], + num_tasks: int, + task_idxs: list[int], + taskify_use_onehot: bool = True, + ) -> None: + datasets = list(datasets) + + datasets = [ + DT.transform( + dataset, + partial( + _update_task_idx_transform, + taskify_keys_graph=taskify_keys_graph, + taskify_keys_node=taskify_keys_node, + num_tasks=num_tasks, + task_idx=task_idx, + use_onehot=taskify_use_onehot, + ), + ) + for task_idx, dataset in zip(task_idxs, datasets) + ] + super().__init__(datasets) + + for dataset in self.datasets: + if not isinstance(dataset, DatasetProtocol): + raise TypeError( + f"Expected dataset to be an instance of DatasetProtocol, " + f"but got {dataset.__class__.__qualname__}." + ) + + if len(dataset.atoms_metadata) != len(dataset): + raise ValueError( + f"Expected atoms_metadata to have the same length as the dataset, " + f"but got {len(dataset.atoms_metadata)=} and {len(dataset)=}." + ) + + @property + @cache + def atoms_metadata(self) -> np.ndarray: + return np.concatenate([d.atoms_metadata for d in self.datasets]) + + def data_sizes(self, indices: list[int]) -> np.ndarray: + return self.atoms_metadata[indices] + + +class MTDatasetConfig(TypedConfig): + balanced: bool | None = None + strict: bool = True + + taskify_keys_graph: list[str] = ["y", "y_scale", "force_scale"] + """Converts the graph-level attributes to a one-hot vector * attr.""" + taskify_keys_node: list[str] = ["force"] + """Converts the node-level attributes to a one-hot vector * attr.""" + taskify_use_onehot: bool = True + """If True, the one-hot vector is used. If False, a vector of ones is used. (Should be True for most cases.)""" + + sample_type: Literal["uniform", "temperature"] | None = None + """ + The type of sampling to use for the datasets. + If `None`, the value of `balanced` will be used to determine the sampling type: + - If `balanced` is `True`, `sample_type` will be set to "uniform". + - If `balanced` is `False`, `sample_type` will be set to "temperature". + """ + sample_temperature: float | None = 1.0 + """ + The temperature to use for temperature sampling. + If `None`, the temperature will be set to 1.0. + """ + + +def _uniform_sampling(dataset_sizes: list[int]): + return [1.0] * len(dataset_sizes) + + +def _temperature_sampling(dataset_sizes: list[int], temp: float): + total_size = sum(dataset_sizes) # 3.25 mil + return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes] + + +def merged_dataset(dataset_sizes_list: list[int], ratios_list: list[float]): + dataset_sizes = np.array(dataset_sizes_list) + ratios = np.array(ratios_list) + + # Calculate the target size of the final dataset + target_size = sum(dataset_sizes) / sum(ratios) + + # Calculate the minimum expansion factor for each dataset + expansion_factors = target_size * ratios / dataset_sizes + + # Make sure that the expansion factors are all at least 1.0 + expansion_factors = expansion_factors / np.min(expansion_factors) + + # Calculate the number of samples to take from each dataset + samples_per_dataset = np.ceil( + dataset_sizes * (expansion_factors / np.min(expansion_factors)) + ).astype(int) + + samples_per_dataset = cast(list[int], samples_per_dataset.tolist()) + return samples_per_dataset + + +class MTSampledDataset(_MTConcatDataset[TDataset], Generic[TDataset]): + """ + Takes a list of datasets, and scales the loss weights of each dataset by + the number of graphs and nodes in the dataset. This is useful for combining + datasets with different numbers of graphs and nodes. + """ + + @override + def __init__( + self, + datasets: list[TDataset], + config: MTDatasetConfig, + *, + num_tasks: int | None = None, + task_idxs: list[int] | None = None, + ignore_balancing: bool = False, + ) -> None: + if num_tasks is None: + num_tasks = len(datasets) + + if task_idxs is None: + task_idxs = list(range(num_tasks)) + + sample_type = config.sample_type + sample_temperature = config.sample_temperature + + if sample_type is None: + if config.balanced is None: + raise ValueError( + "Either `sample_type` or `balanced` must be specified in `MTSampledDataset.__init__`." + ) + + if config.balanced: + sample_type = "uniform" + else: + sample_type = "temperature" + sample_temperature = 1.0 + + log.critical( + f"Using {sample_type=} and {sample_temperature=} because " + f"`sample_type` is None and `balanced` is {config.balanced}." + ) + + assert ( + sample_type + in { + "uniform", + "temperature", + } + ), f"{config.sample_type=} must be one of 'balanced', 'uniform', or 'temperature'." + + if ignore_balancing: + log.critical( + "Ignoring balancing because `ignore_balancing` is True in `MTSampledDataset.__init__`." + ) + sample_type = "temperature" + sample_temperature = 1.0 + + match sample_type: + case "uniform": + # we use uniform sampling (i.e., we sample each dataset with equal probability) + ratios = _uniform_sampling([len(d) for d in datasets]) + case "temperature": + assert ( + sample_temperature is not None + ), "sample_temperature must be specified if sample_type is 'temperature'." + ratios = _temperature_sampling( + [len(d) for d in datasets], sample_temperature + ) + case _: + raise ValueError(f"{sample_type=} is not a valid sampling type.") + + ratios = [r / sum(ratios) for r in ratios] + log.info(f"Using {ratios=} for {sample_type=}.") + + expanded_dataset_sizes = merged_dataset([len(d) for d in datasets], ratios) + datasets = [ + expand_dataset(dataset, n=n) + for dataset, n in zip(datasets, expanded_dataset_sizes) + ] + + super().__init__( + datasets, + taskify_keys_graph=config.taskify_keys_graph, + taskify_keys_node=config.taskify_keys_node, + num_tasks=num_tasks, + task_idxs=task_idxs, + taskify_use_onehot=config.taskify_use_onehot, + ) + + post_create_dataset(self, strict=config.strict) + + def representative_batch_for_testing(self, *, n: int = 1, start_index: int = 0): + data_list = [ + cast(BaseData, dataset[index]) + for dataset in self.datasets + for index in range(start_index, start_index + n) + ] + return data_list diff --git a/src/jmp/modules/dataset/dataset_transform.py b/src/jmp/modules/dataset/dataset_transform.py new file mode 100644 index 0000000..3686dd7 --- /dev/null +++ b/src/jmp/modules/dataset/dataset_transform.py @@ -0,0 +1,269 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from collections import abc +from collections.abc import Callable +from functools import cache, partial +from logging import getLogger +from typing import Any, cast + +import numpy as np +import torch +import wrapt +from typing_extensions import override + +from .. import transforms as T +from .dataset_typing import TDataset + +log = getLogger(__name__) + + +def transform( + dataset: TDataset, + transform: Callable[[Any], Any], + copy_data: bool = True, +) -> TDataset: + """ + Applies a transformation/mapping function to all elements of the dataset. + + Args: + dataset (Dataset): The dataset to transform. + transform (Callable): The transformation function. + copy_data (bool, optional): Whether to copy the data before transforming. Defaults to True. + """ + + class _TransformedDataset(wrapt.ObjectProxy): + @override + def __getitem__(self, idx): + nonlocal copy_data, transform + + assert transform is not None, "Transform must be defined." + data = self.__wrapped__.__getitem__(idx) + if copy_data: + data = copy.deepcopy(data) + data = transform(data) + return data + + return cast(TDataset, _TransformedDataset(dataset)) + + +def atomref_transform( + dataset: TDataset, + refs: dict[str, torch.Tensor], + keep_raw: bool = False, +) -> TDataset: + """ + Subtracts the atomrefs from the target properties of the dataset. For a data sample x and atomref property p, + the transformed property is `x[p] = x[p] - atomref[x.atomic_numbers].sum()`. + + This is primarily used to normalize energies using a "linear referencing" scheme. + + Args: + dataset (Dataset): The dataset to transform. + refs (dict[str, torch.Tensor]): The atomrefs to subtract from the target properties. + keep_raw (bool, optional): Whether to keep the original properties, renamed as `{target}_raw`. Defaults to False. + """ + # Convert the refs to tensors + refs_dict: dict[str, torch.Tensor] = {} + for k, v in refs.items(): + if isinstance(v, list): + v = torch.tensor(v) + elif isinstance(v, np.ndarray): + v = torch.from_numpy(v).float() + elif not torch.is_tensor(v): + raise TypeError(f"Invalid type for {k} in atomrefs: {type(v)}") + refs_dict[k] = v + + return transform( + dataset, + partial(T.atomref_transform, refs=refs_dict, keep_raw=keep_raw), + copy_data=False, + ) + + +def expand_dataset(dataset: TDataset, n: int) -> TDataset: + """ + Expands the dataset to have `n` elements by repeating the elements of the dataset as many times as necessary. + + Args: + dataset (Dataset): The dataset to expand. + n (int): The desired length of the dataset. + """ + if not isinstance(dataset, abc.Sized): + raise TypeError( + f"expand_dataset ({n}) must be used with a dataset that is an instance of abc.Sized " + f"for {dataset.__class__.__qualname__} " + ) + + og_size = len(dataset) + if og_size > n: + raise ValueError( + f"expand_dataset ({n}) must be greater than or equal to the length of the dataset " + f"({len(dataset)}) for {dataset.__class__.__qualname__}" + ) + + class _ExpandedDataset(wrapt.ObjectProxy): + @override + def __len__(self): + nonlocal n + return n + + @override + def __getitem__(self, index: int): + nonlocal n, og_size + if index < 0 or index >= n: + raise IndexError( + f"Index {index} is out of bounds for dataset of size {n}." + ) + return self.__wrapped__.__getitem__(index % og_size) + + @cache + def _atoms_metadata_cached(self): + """ + We want to retrieve the atoms metadata for the expanded dataset. + This includes repeating the atoms metadata for the elemens that are repeated. + """ + + # the out metadata shape should be (n,) + nonlocal n, og_size + + metadata = self.__wrapped__.atoms_metadata + metadata = np.resize(metadata, (n,)) + log.debug( + f"Expanded the atoms metadata for {self.__class__.__name__} ({og_size} => {len(metadata)})." + ) + return metadata + + @property + def atoms_metadata(self): + return self._atoms_metadata_cached() + + dataset = cast(TDataset, _ExpandedDataset(dataset)) + log.info(f"Expanded dataset {dataset.__class__.__name__} from {og_size} to {n}.") + return dataset + + +def first_n_transform(dataset: TDataset, *, n: int) -> TDataset: + """ + Returns a new dataset that contains the first `n` elements of the original dataset. + + Args: + dataset (Dataset): The dataset to transform. + n (int): The number of elements to keep. + """ + if not isinstance(dataset, abc.Sized): + raise TypeError( + f"first_n ({n}) must be used with a dataset that is an instance of abc.Sized " + f"for {dataset.__class__.__qualname__} " + ) + + if len(dataset) < n: + raise ValueError( + f"first_n ({n}) must be less than or equal to the length of the dataset " + f"({len(dataset)}) for {dataset.__class__.__qualname__} " + ) + + class _FirstNDataset(wrapt.ObjectProxy): + @override + def __getitem__(self, idx: int): + nonlocal n + + if idx < 0 or idx >= n: + raise IndexError( + f"Index {idx} is out of bounds for dataset of size {n}." + ) + + return self.__wrapped__.__getitem__(idx) + + @override + def __len__(self): + nonlocal n + return n + + @cache + def _atoms_metadata_cached(self): + """We only want to retrieve the atoms metadata for the first n elements.""" + nonlocal n + + metadata = self.__wrapped__.atoms_metadata + og_size = len(metadata) + metadata = metadata[:n] + + log.info( + f"Retrieved the first {n} atoms metadata for {self.__class__.__name__} ({og_size} => {len(metadata)})." + ) + return metadata + + @property + def atoms_metadata(self): + return self._atoms_metadata_cached() + + return cast(TDataset, _FirstNDataset(dataset)) + + +def sample_n_transform(dataset: TDataset, *, n: int, seed: int) -> TDataset: + """ + Similar to first_n_transform, but samples n elements randomly from the dataset. + + Args: + dataset (Dataset): The dataset to transform. + n (int): The number of elements to sample. + seed (int): The random seed to use for sampling. + """ + + if not isinstance(dataset, abc.Sized): + raise TypeError( + f"sample_n ({n}) must be used with a dataset that is an instance of abc.Sized " + f"for {dataset.__class__.__qualname__} " + ) + + if len(dataset) < n: + raise ValueError( + f"sample_n ({n}) must be less than or equal to the length of the dataset " + f"({len(dataset)}) for {dataset.__class__.__qualname__} " + ) + + sampled_indices = np.random.default_rng(seed).choice(len(dataset), n, replace=False) + + class _SampleNDataset(wrapt.ObjectProxy): + @override + def __getitem__(self, idx: int): + nonlocal n, sampled_indices + + if idx < 0 or idx >= n: + raise IndexError( + f"Index {idx} is out of bounds for dataset of size {n}." + ) + + return self.__wrapped__.__getitem__(sampled_indices[idx]) + + @override + def __len__(self): + nonlocal n + return n + + @cache + def _atoms_metadata_cached(self): + """We only want to retrieve the atoms metadata for the sampled n elements.""" + nonlocal n, sampled_indices + + metadata = self.__wrapped__.atoms_metadata + og_size = len(metadata) + metadata = metadata[sampled_indices] + + log.info( + f"Retrieved the sampled {n} atoms metadata for {self.__class__.__name__} ({og_size} => {len(metadata)})." + ) + return metadata + + @property + def atoms_metadata(self): + return self._atoms_metadata_cached() + + return cast(TDataset, _SampleNDataset(dataset)) diff --git a/src/jmp/modules/dataset/dataset_typing.py b/src/jmp/modules/dataset/dataset_typing.py new file mode 100644 index 0000000..15e45b1 --- /dev/null +++ b/src/jmp/modules/dataset/dataset_typing.py @@ -0,0 +1,26 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Protocol, runtime_checkable + +import numpy as np +from torch_geometric.data.data import BaseData +from typing_extensions import TypeVar + + +@runtime_checkable +class DatasetProtocol(Protocol): + @property + def atoms_metadata(self) -> np.ndarray: ... + + def __getitem__(self, index: int, /) -> BaseData: ... + + def __len__(self) -> int: ... + + +TDataset = TypeVar("TDataset", bound=DatasetProtocol, infer_variance=True) diff --git a/src/jmp/modules/early_stopping.py b/src/jmp/modules/early_stopping.py new file mode 100644 index 0000000..140999c --- /dev/null +++ b/src/jmp/modules/early_stopping.py @@ -0,0 +1,120 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import math +from logging import getLogger + +from lightning.fabric.utilities.rank_zero import _get_rank +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import EarlyStopping as LightningEarlyStopping +from lightning.pytorch.utilities.rank_zero import rank_prefixed_message +from typing_extensions import override + +log = getLogger(__name__) + + +class EarlyStoppingWithMinLR(LightningEarlyStopping): + def __init__( + self, + monitor: str, + min_delta: float = 0, + min_lr: float | None = None, + patience: int = 3, + verbose: bool = True, + mode: str = "min", + strict: bool = True, + check_finite: bool = True, + stopping_threshold: float | None = None, + divergence_threshold: float | None = None, + check_on_train_epoch_end: bool | None = None, + log_rank_zero_only: bool = False, + ): + super().__init__( + monitor, + min_delta, + patience, + verbose, + mode, + strict, + check_finite, + stopping_threshold, + divergence_threshold, + check_on_train_epoch_end, + log_rank_zero_only, + ) + + self.min_lr = min_lr + + @override + @staticmethod + def _log_info( + trainer: Trainer | None, message: str, log_rank_zero_only: bool + ) -> None: + rank = _get_rank() + if trainer is not None and trainer.world_size <= 1: + rank = None + message = rank_prefixed_message(message, rank) + if rank is None or not log_rank_zero_only or rank == 0: + log.critical(message) + + @override + def _run_early_stopping_check(self, trainer: Trainer): + """Checks whether the early stopping condition is met and if so tells the trainer to stop the training.""" + logs = trainer.callback_metrics + + # Disable early_stopping with fast_dev_run + if getattr(trainer, "fast_dev_run", False): + return + + should_stop, reason = False, None + + if not should_stop: + should_stop, reason = self._evaluate_stopping_criteria_min_lr(trainer) + + # If metric present + if not should_stop and self._validate_condition_metric(logs): + current = logs[self.monitor].squeeze() + should_stop, reason = self._evaluate_stopping_criteria(current) + + # stop every ddp process if any world process decides to stop + should_stop = trainer.strategy.reduce_boolean_decision(should_stop, all=False) + trainer.should_stop = trainer.should_stop or should_stop + if should_stop: + self.stopped_epoch = trainer.current_epoch + if reason and self.verbose: + self._log_info(trainer, reason, self.log_rank_zero_only) + + def _evaluate_stopping_criteria_min_lr( + self, trainer: Trainer + ) -> tuple[bool, str | None]: + if self.min_lr is None: + return False, None + + # Get the maximum LR across all param groups in all optimizers + model_max_lr = max( + [ + param_group["lr"] + for optimizer in trainer.optimizers + for param_group in optimizer.param_groups + ] + ) + if not isinstance(model_max_lr, float) or not math.isfinite(model_max_lr): + return False, None + + # If the maximum LR is less than the minimum LR, stop training + if model_max_lr >= self.min_lr: + return False, None + + return True, ( + "Stopping threshold reached: " + f"The maximum LR of the model across all param groups is {model_max_lr:.2e} " + f"which is less than the minimum LR {self.min_lr:.2e}" + ) + + def on_early_stopping(self, trainer: Trainer): + pass diff --git a/src/jmp/modules/ema.py b/src/jmp/modules/ema.py new file mode 100644 index 0000000..71e58d1 --- /dev/null +++ b/src/jmp/modules/ema.py @@ -0,0 +1,362 @@ +import contextlib +import copy +import threading +from typing import Iterable + +import lightning.pytorch as pl +import torch +from jmp.lightning import TypedConfig +from lightning.pytorch import Callback +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from typing_extensions import override + + +class EMA(Callback): + """ + Implements Exponential Moving Averaging (EMA). + + When training a model, this callback will maintain moving averages of the trained parameters. + When evaluating, we use the moving averages copy of the trained parameters. + When saving, we save an additional set of parameters with the prefix `ema`. + + Args: + decay: The exponential decay used when calculating the moving average. Has to be between 0-1. + validate_original_weights: Validate the original weights, as apposed to the EMA weights. + every_n_steps: Apply EMA every N steps. + cpu_offload: Offload weights to CPU. + """ + + @override + def __init__( + self, + decay: float, + validate_original_weights: bool = False, + every_n_steps: int = 1, + cpu_offload: bool = False, + ): + if not (0 <= decay <= 1): + raise MisconfigurationException("EMA decay value must be between 0 and 1") + self.decay = decay + self.validate_original_weights = validate_original_weights + self.every_n_steps = every_n_steps + self.cpu_offload = cpu_offload + + @override + def on_fit_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + device = pl_module.device if not self.cpu_offload else torch.device("cpu") + trainer.optimizers = [ + EMAOptimizer( + optim, + device=device, + decay=self.decay, + every_n_steps=self.every_n_steps, + current_step=trainer.global_step, + ) + for optim in trainer.optimizers + if not isinstance(optim, EMAOptimizer) + ] + + @override + def on_validation_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + @override + def on_validation_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + @override + def on_test_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + @override + def on_test_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self._should_validate_ema_weights(trainer): + self.swap_model_weights(trainer) + + def _should_validate_ema_weights(self, trainer: "pl.Trainer") -> bool: + return not self.validate_original_weights and self._ema_initialized(trainer) + + def _ema_initialized(self, trainer: "pl.Trainer") -> bool: + return any( + isinstance(optimizer, EMAOptimizer) for optimizer in trainer.optimizers + ) + + def swap_model_weights(self, trainer: "pl.Trainer", saving_ema_model: bool = False): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.switch_main_parameter_weights(saving_ema_model) + + @contextlib.contextmanager + def save_ema_model(self, trainer: "pl.Trainer"): + """ + Saves an EMA copy of the model + EMA optimizer states for resume. + """ + self.swap_model_weights(trainer, saving_ema_model=True) + try: + yield + finally: + self.swap_model_weights(trainer, saving_ema_model=False) + + @contextlib.contextmanager + def save_original_optimizer_state(self, trainer: "pl.Trainer"): + for optimizer in trainer.optimizers: + assert isinstance(optimizer, EMAOptimizer) + optimizer.save_original_optimizer_state = True + try: + yield + finally: + for optimizer in trainer.optimizers: + optimizer.save_original_optimizer_state = False + + +@torch.no_grad() +def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, + current_model_tuple, + alpha=(1.0 - decay), + ) + + +def run_ema_update_cpu( + ema_model_tuple, current_model_tuple, decay, pre_sync_stream=None +): + if pre_sync_stream is not None: + pre_sync_stream.synchronize() + + ema_update(ema_model_tuple, current_model_tuple, decay) + + +class EMAOptimizer(torch.optim.Optimizer): + r""" + EMAOptimizer is a wrapper for torch.optim.Optimizer that computes + Exponential Moving Average of parameters registered in the optimizer. + + EMA parameters are automatically updated after every step of the optimizer + with the following formula: + + ema_weight = decay * ema_weight + (1 - decay) * training_weight + + To access EMA parameters, use ``swap_ema_weights()`` context manager to + perform a temporary in-place swap of regular parameters with EMA + parameters. + + Notes: + - EMAOptimizer is not compatible with APEX AMP O2. + + Args: + optimizer (torch.optim.Optimizer): optimizer to wrap + device (torch.device): device for EMA parameters + decay (float): decay factor + + Returns: + returns an instance of torch.optim.Optimizer that computes EMA of + parameters + + Example: + model = Model().to(device) + opt = torch.optim.Adam(model.parameters()) + + opt = EMAOptimizer(opt, device, 0.9999) + + for epoch in range(epochs): + training_loop(model, opt) + + regular_eval_accuracy = evaluate(model) + + with opt.swap_ema_weights(): + ema_eval_accuracy = evaluate(model) + """ + + @override + def __init__( + self, + optimizer: torch.optim.Optimizer, + device: torch.device, + decay: float = 0.9999, + every_n_steps: int = 1, + current_step: int = 0, + ): + self.optimizer = optimizer + self.decay = decay + self.device = device + self.current_step = current_step + self.every_n_steps = every_n_steps + self.save_original_optimizer_state = False + + self.first_iteration = True + self.rebuild_ema_params = True + self.stream = None + self.thread = None + + self.ema_params = () + self.in_saving_ema_model_context = False + + def all_parameters(self) -> Iterable[torch.Tensor]: + return (param for group in self.param_groups for param in group["params"]) + + @override + def step(self, closure=None, **kwargs): + self.join() + + if self.first_iteration: + if any(p.is_cuda for p in self.all_parameters()): + self.stream = torch.cuda.Stream() + + self.first_iteration = False + + if self.rebuild_ema_params: + opt_params = list(self.all_parameters()) + + self.ema_params += tuple( + copy.deepcopy(param.data.detach()).to(self.device) + for param in opt_params[len(self.ema_params) :] + ) + self.rebuild_ema_params = False + + loss = self.optimizer.step(closure) + + if self._should_update_at_step(): + self.update() + self.current_step += 1 + return loss + + def _should_update_at_step(self) -> bool: + return self.current_step % self.every_n_steps == 0 + + @torch.no_grad() + def update(self): + if self.stream is not None: + self.stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.stream): + current_model_state = tuple( + param.data.to(self.device, non_blocking=True) + for param in self.all_parameters() + ) + + if self.device.type == "cuda": + ema_update(self.ema_params, current_model_state, self.decay) + + if self.device.type == "cpu": + self.thread = threading.Thread( + target=run_ema_update_cpu, + args=( + self.ema_params, + current_model_state, + self.decay, + self.stream, + ), + ) + self.thread.start() + + def swap_tensors(self, tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) + + def switch_main_parameter_weights(self, saving_ema_model: bool = False): + self.join() + self.in_saving_ema_model_context = saving_ema_model + for param, ema_param in zip(self.all_parameters(), self.ema_params): + self.swap_tensors(param.data, ema_param) + + @contextlib.contextmanager + def swap_ema_weights(self, enabled: bool = True): + r""" + A context manager to in-place swap regular parameters with EMA + parameters. + It swaps back to the original regular parameters on context manager + exit. + + Args: + enabled (bool): whether the swap should be performed + """ + + if enabled: + self.switch_main_parameter_weights() + try: + yield + finally: + if enabled: + self.switch_main_parameter_weights() + + def __getattr__(self, name): + return getattr(self.optimizer, name) + + def join(self): + if self.stream is not None: + self.stream.synchronize() + + if self.thread is not None: + self.thread.join() + + @override + def state_dict(self): + self.join() + + if self.save_original_optimizer_state: + return self.optimizer.state_dict() + + # if we are in the context of saving an EMA model, the EMA weights are in the modules' actual weights + ema_params = ( + self.ema_params + if not self.in_saving_ema_model_context + else list(self.all_parameters()) + ) + state_dict = { + "opt": self.optimizer.state_dict(), + "ema": ema_params, + "current_step": self.current_step, + "decay": self.decay, + "every_n_steps": self.every_n_steps, + } + return state_dict + + @override + def load_state_dict(self, state_dict): + self.join() + + self.optimizer.load_state_dict(state_dict["opt"]) + self.ema_params = tuple( + param.to(self.device) for param in copy.deepcopy(state_dict["ema"]) + ) + self.current_step = state_dict["current_step"] + self.decay = state_dict["decay"] + self.every_n_steps = state_dict["every_n_steps"] + self.rebuild_ema_params = False + + @override + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + self.rebuild_ema_params = True + + +class EMAConfig(TypedConfig): + decay: float + validate_original_weights: bool = False + every_n_steps: int = 1 + cpu_offload: bool = False + + def construct_callback(self): + return EMA( + decay=self.decay, + validate_original_weights=self.validate_original_weights, + every_n_steps=self.every_n_steps, + cpu_offload=self.cpu_offload, + ) diff --git a/src/jmp/modules/metadata.py b/src/jmp/modules/metadata.py new file mode 100644 index 0000000..5eb047b --- /dev/null +++ b/src/jmp/modules/metadata.py @@ -0,0 +1,104 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import logging +from pathlib import Path +from typing import Protocol, runtime_checkable + +import numpy as np +import torch +from torch.utils.data import ConcatDataset, Dataset + + +@runtime_checkable +class _DatasetWrapper(Protocol): + @property + def dataset(self) -> Dataset: ... + + +@runtime_checkable +class _HasMetadataPath(Protocol): + @property + def metadata_path(self) -> Path: ... + + +@runtime_checkable +class _HasMetadataProperty(Protocol): + @property + def atoms_metadata(self) -> torch.Tensor | np.ndarray: ... + + +def _dataset_repr(dataset): + if isinstance(dataset, _DatasetWrapper): + name = dataset.__class__.__qualname__ + inner_repr = _dataset_repr(dataset.dataset) + return f"{name}({inner_repr})" + + return dataset.__class__.__qualname__ + + +def _get_metadata_property(dataset): + if isinstance(dataset, _HasMetadataProperty): + return ( + dataset.atoms_metadata.numpy() + if isinstance(dataset.atoms_metadata, torch.Tensor) + else dataset.atoms_metadata + ) + + return None + + +def _get_metadata_path(dataset): + if isinstance(dataset, _HasMetadataPath): + if not dataset.metadata_path.exists(): + return None + return np.load(dataset.metadata_path)["natoms"] + + return None + + +def get_metadata( + dataset, + *, + strict: bool, +): + fns = [_get_metadata_property, _get_metadata_path] + metadata = next((fn(dataset) for fn in fns if fn(dataset) is not None), None) + if metadata is None: + if strict: + raise RuntimeError(f"Failed to load metadata for {_dataset_repr(dataset)}") + else: + logging.warning(f"Failed to load metadata for {_dataset_repr(dataset)}") + return metadata + + +def post_create_dataset(dataset: Dataset, *, strict: bool): + if isinstance(dataset, (ConcatDataset)): + metadata = [get_metadata(d, strict=strict) for d in dataset.datasets] + all_set = True + for dataset_idx, (m, dataset_) in enumerate(zip(metadata, dataset.datasets)): + if m is not None: + logging.debug( + f"Loaded metadata of size ({len(m)}) for {_dataset_repr(dataset_)}" + ) + continue + + all_set = False + error_msg = ( + f"Failed to load metadata for {dataset_idx=}: {_dataset_repr(dataset_)}" + ) + if strict: + raise RuntimeError(error_msg) + else: + logging.warning(error_msg) + if all_set and False: + setattr(dataset, "atoms_metadata", np.concatenate(metadata)) + else: + metadata = get_metadata(dataset, strict=strict) + if strict and metadata is None: + raise RuntimeError(f"Failed to load metadata for {_dataset_repr(dataset)}") diff --git a/src/jmp/modules/metrics.py b/src/jmp/modules/metrics.py new file mode 100644 index 0000000..cdec5d6 --- /dev/null +++ b/src/jmp/modules/metrics.py @@ -0,0 +1,174 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections.abc import Callable +from functools import partial +from typing import TypedDict + +import torch +import torch.nn as nn +import torchmetrics +from jmp.lightning.util.typed import TypedModuleList +from torch_geometric.data import Batch +from typing_extensions import NotRequired, override + +from .transforms.normalize import denormalize_batch +from .transforms.units import VALID_UNITS, Unit, _determine_factor + + +class MetricConfig(TypedDict): + idx: int + additional_units: NotRequired[list[str]] + + +def _transform(x: torch.Tensor, *, from_: Unit, to: Unit): + factor = _determine_factor(from_, to) + return x * factor + + +class FMTaskMetrics(nn.Module): + @override + def __init__( + self, + name: str, + config: MetricConfig, + num_tasks: int, + free_atoms_only: bool = True, + ): + super().__init__() + + self.name = name + self.config = config + self.num_tasks = num_tasks + self.free_atoms_only = free_atoms_only + + self.energy_mae = torchmetrics.MeanAbsoluteError() + self.forces_mae = torchmetrics.MeanAbsoluteError() + + if units := self.config.get("additional_units", []): + for unit in units: + if unit not in VALID_UNITS: + raise ValueError( + f"Invalid unit: {unit}. Valid units: {VALID_UNITS}" + ) + self.energy_mae_additional = TypedModuleList( + [torchmetrics.MeanAbsoluteError() for _ in units] + ) + self.forces_mae_additional = TypedModuleList( + [torchmetrics.MeanAbsoluteError() for _ in units] + ) + + @override + def forward(self, batch: Batch, energy: torch.Tensor, forces: torch.Tensor): + metrics: dict[str, torchmetrics.Metric] = {} + + self._energy_mae(batch, energy, self.energy_mae) + self._forces_mae(batch, forces, self.forces_mae) + + metrics["energy_mae"] = self.energy_mae + metrics["forces_mae"] = self.forces_mae + + if additional := self.config.get("additional_units", []): + for unit, energy_metric, forces_metric in zip( + additional, self.energy_mae_additional, self.forces_mae_additional + ): + assert ( + unit in VALID_UNITS + ), f"Invalid unit: {unit}. Valid units: {VALID_UNITS}" + sanitized_unit = unit.replace("/", "_") + self._energy_mae( + batch, + energy, + energy_metric, + transform=partial(_transform, from_="eV", to=unit), + ) + self._forces_mae( + batch, + forces, + forces_metric, + transform=partial(_transform, from_="eV", to=unit), + ) + + metrics[f"energy_mae_{sanitized_unit}"] = energy_metric + metrics[f"forces_mae_{sanitized_unit}"] = forces_metric + + return {f"{self.name}/{name}": metric for name, metric in metrics.items()} + + def _forces_mae( + self, + batch: Batch, + forces: torch.Tensor, + forces_mae: torchmetrics.MeanAbsoluteError, + *, + transform: Callable[[torch.Tensor], torch.Tensor] | None = None, + ): + task_idx = self.config["idx"] + + forces_mask = batch.task_mask[:, task_idx] # (b,) + forces_mask = forces_mask[batch.batch] # (n,) + if self.free_atoms_only: + forces_mask = forces_mask & ~batch.fixed + forces_target = batch.force[..., task_idx][forces_mask] + forces_pred = forces[..., task_idx][forces_mask] + if transform is not None: + forces_target = transform(forces_target) + forces_pred = transform(forces_pred) + + forces_mae(forces_pred, forces_target) + + def _energy_mae( + self, + batch: Batch, + energy: torch.Tensor, + energy_mae: torchmetrics.MeanAbsoluteError, + *, + transform: Callable[[torch.Tensor], torch.Tensor] | None = None, + ): + task_idx = self.config["idx"] + + energy_mask = batch.task_mask[:, task_idx] # (b,) + energy_target = batch.y[..., task_idx][energy_mask] # (b,) + energy_pred = energy[..., task_idx][energy_mask] # (b,) + if transform is not None: + energy_target = transform(energy_target) + energy_pred = transform(energy_pred) + + energy_mae(energy_pred, energy_target) + + +class FMMetrics(nn.Module): + @override + def __init__( + self, + tasks: dict[str, MetricConfig], + *, + denormalize: bool, + free_atoms_only: bool = True, + ): + super().__init__() + + self.denormalize = denormalize + self.task_metrics = TypedModuleList( + [ + FMTaskMetrics( + name, config, num_tasks=len(tasks), free_atoms_only=free_atoms_only + ) + for name, config in tasks.items() + ] + ) + + @override + def forward(self, batch: Batch, energy: torch.Tensor, forces: torch.Tensor): + if self.denormalize: + batch, d = denormalize_batch(batch, {"y": energy, "force": forces}) + energy, forces = d["y"], d["force"] + + metrics: dict[str, torchmetrics.Metric] = {} + for task_metrics in self.task_metrics: + metrics.update(task_metrics(batch, energy, forces)) + return metrics diff --git a/src/jmp/modules/scaling/__init__.py b/src/jmp/modules/scaling/__init__.py new file mode 100644 index 0000000..d8e9d01 --- /dev/null +++ b/src/jmp/modules/scaling/__init__.py @@ -0,0 +1,11 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .scale_factor import ScaleFactor + +__all__ = ["ScaleFactor"] diff --git a/src/jmp/modules/scaling/compat.py b/src/jmp/modules/scaling/compat.py new file mode 100644 index 0000000..efd6391 --- /dev/null +++ b/src/jmp/modules/scaling/compat.py @@ -0,0 +1,88 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import json +import logging +from pathlib import Path +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn + +from .scale_factor import ScaleFactor + +ScaleDict = Union[Dict[str, float], Dict[str, torch.Tensor]] + + +def _load_scale_dict(scale_file: Optional[Union[str, ScaleDict]]): + """ + Loads scale factors from either: + - a JSON file mapping scale factor names to scale values + - a python dictionary pickled object (loaded using `torch.load`) mapping scale factor names to scale values + - a dictionary mapping scale factor names to scale values + """ + if not scale_file: + return None + + if isinstance(scale_file, dict): + if not scale_file: + logging.warning("Empty scale dictionary provided to model.") + return scale_file + + path = Path(scale_file) + if not path.exists(): + raise ValueError(f"Scale file {path} does not exist.") + + scale_dict: Optional[ScaleDict] = None + if path.suffix == ".pt": + scale_dict = torch.load(path) + elif path.suffix == ".json": + with open(path, "r") as f: + scale_dict = json.load(f) + + if isinstance(scale_dict, dict): + # old json scale factors have a comment field that has the model name + scale_dict.pop("comment", None) + else: + raise ValueError(f"Unsupported scale file extension: {path.suffix}") + + if not scale_dict: + return None + + return scale_dict + + +def load_scales_compat(module: nn.Module, scale_file: Optional[Union[str, ScaleDict]]): + scale_dict = _load_scale_dict(scale_file) + if not scale_dict: + return + + scale_factors = { + module.name or name: (module, name) + for name, module in module.named_modules() + if isinstance(module, ScaleFactor) + } + loaded_factors = set[str]() + logging.critical( + f"Found the following scale factors: {[(k, name) for k, (_, name) in scale_factors.items()]}." + ) + for name, scale in scale_dict.items(): + if name not in scale_factors: + logging.warning(f"Scale factor {name} not found in model") + continue + + scale_module, module_name = scale_factors[name] + logging.debug(f"Loading scale factor {scale} for ({name} => {module_name})") + scale_module.set_(scale) + loaded_factors.add(name) + + not_loaded_factors = set(scale_factors.keys()) - loaded_factors + logging.critical( + f"Loaded the following scale factors: {list(loaded_factors)}.\n" + f"Did not load the following scale factors: {list(not_loaded_factors)}." + ) diff --git a/src/jmp/modules/scaling/scale_factor.py b/src/jmp/modules/scaling/scale_factor.py new file mode 100644 index 0000000..53fa32a --- /dev/null +++ b/src/jmp/modules/scaling/scale_factor.py @@ -0,0 +1,173 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import itertools +import logging +import math +from collections.abc import Callable +from contextlib import contextmanager +from typing import TypedDict + +import torch +import torch.nn as nn + + +class _Stats(TypedDict): + variance_in: float + variance_out: float + n_samples: int + + +IndexFn = Callable[[], None] + + +def _check_consistency(old: torch.Tensor, new: torch.Tensor, key: str): + if not torch.allclose(old, new): + raise ValueError( + f"Scale factor parameter {key} is inconsistent with the loaded state dict.\n" + f"Old: {old}\n" + f"Actual: {new}" + ) + + +class ScaleFactor(nn.Module): + scale_factor: torch.Tensor + + name: str | None = None + index_fn: IndexFn | None = None + stats: _Stats | None = None + + def __init__( + self, + name: str | None = None, + enforce_consistency: bool = True, + ): + super().__init__() + + self.name = name + + self.index_fn = None + self.stats = None + self.register_buffer("scale_factor", torch.tensor(0.0)) + if enforce_consistency: + _ = self._register_load_state_dict_pre_hook(self._enforce_consistency) + + def _enforce_consistency( + self, + state_dict, + prefix, + _local_metadata, + _strict, + _missing_keys, + _unexpected_keys, + _error_msgs, + ): + if not self.fitted: + return + + persistent_buffers = { + k: v + for k, v in self._buffers.items() + if k not in self._non_persistent_buffers_set + } + local_name_params = itertools.chain( + self._parameters.items(), persistent_buffers.items() + ) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + if key not in state_dict: + continue + + input_param = state_dict[key] + _check_consistency(old=param, new=input_param, key=key) + + @property + def fitted(self): + return bool((self.scale_factor != 0.0).item()) + + @torch.jit.unused + def reset_(self): + self.scale_factor.zero_() + + @torch.jit.unused + def set_(self, scale: float | torch.Tensor): + if self.fitted: + _check_consistency( + old=self.scale_factor, + new=torch.tensor(scale) if isinstance(scale, (float, int)) else scale, + key="scale_factor", + ) + self.scale_factor.fill_(scale) + + @torch.jit.unused + def initialize_(self, *, index_fn: IndexFn | None = None): + self.index_fn = index_fn + + @contextmanager + @torch.jit.unused + def fit_context_(self): + self.stats = _Stats(variance_in=0.0, variance_out=0.0, n_samples=0) + yield + del self.stats + self.stats = None + + @torch.jit.unused + def fit_(self): + assert self.stats, "Stats not set" + for k, v in self.stats.items(): + assert v > 0, f"{k} is {v}" + + self.stats["variance_in"] = self.stats["variance_in"] / self.stats["n_samples"] + self.stats["variance_out"] = ( + self.stats["variance_out"] / self.stats["n_samples"] + ) + + ratio = self.stats["variance_out"] / self.stats["variance_in"] + value = math.sqrt(1 / ratio) + + self.set_(value) + + stats = dict(**self.stats) + return stats, ratio, value + + @torch.no_grad() + @torch.jit.unused + def _observe(self, x: torch.Tensor, ref: torch.Tensor | None = None): + if self.stats is None: + logging.debug("Observer not initialized but self.observe() called") + return + + n_samples = x.shape[0] + self.stats["variance_out"] += torch.mean(torch.var(x, dim=0)).item() * n_samples + + if ref is None: + self.stats["variance_in"] += n_samples + else: + self.stats["variance_in"] += ( + torch.mean(torch.var(ref, dim=0)).item() * n_samples + ) + self.stats["n_samples"] += n_samples + + def forward( + self, + x: torch.Tensor, + *, + ref: torch.Tensor | None = None, + ): + if self.index_fn is not None: + self.index_fn() + + if self.fitted: + x = x * self.scale_factor + + if not torch.jit.is_scripting(): + self._observe(x, ref=ref) + + return x diff --git a/src/jmp/modules/scaling/util.py b/src/jmp/modules/scaling/util.py new file mode 100644 index 0000000..57bf800 --- /dev/null +++ b/src/jmp/modules/scaling/util.py @@ -0,0 +1,31 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import logging + +import torch.nn as nn + +from .scale_factor import ScaleFactor + + +def ensure_fitted(module: nn.Module, warn: bool = False): + for name, child in module.named_modules(): + if not isinstance(child, ScaleFactor) or child.fitted: + continue + if child.name is not None: + name = f"{child.name} ({name})" + msg = ( + f"Scale factor {name} is not fitted. " + "Please make sure that you either (1) load a checkpoint with fitted scale factors, " + "(2) explicitly load scale factors using the `model.scale_file` attribute, or " + "(3) fit the scale factors using the `fit.py` script." + ) + if warn: + logging.warning(msg) + else: + raise ValueError(msg) diff --git a/src/jmp/modules/scheduler/gradual_warmup_lr.py b/src/jmp/modules/scheduler/gradual_warmup_lr.py new file mode 100644 index 0000000..1ea7b3b --- /dev/null +++ b/src/jmp/modules/scheduler/gradual_warmup_lr.py @@ -0,0 +1,82 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler + + +class GradualWarmupScheduler(_LRScheduler): + def __init__( + self, + optimizer: Optimizer, + warmup_start_lr: float, + warmup_steps: int, + after_scheduler=None, + ): + self.warmup_start_lr = warmup_start_lr + self.warmup_steps = warmup_steps + self.after_scheduler = after_scheduler + self.finished = False + + super().__init__(optimizer) + + def get_lr(self): + if self.last_epoch == 0: + return [self.warmup_start_lr] * len(self.base_lrs) + + if self.last_epoch < self.warmup_steps: + return [ + group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_steps - 1) + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + + if self.last_epoch != self.warmup_steps and self.after_scheduler: + if not self.finished: + self.after_scheduler.base_lrs = self.base_lrs + self.finished = True + return self.after_scheduler.get_last_lr() + return self.base_lrs + + def step_ReduceLROnPlateau(self, metrics, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.last_epoch = ( + epoch if epoch != 0 else 1 + ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning + if self.last_epoch <= self.warmup_steps: + warmup_lr = [ + base_lr + * ( + (self.warmup_start_lr - 1.0) * self.last_epoch / self.warmup_steps + + 1.0 + ) + for base_lr in self.base_lrs + ] + for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): + param_group["lr"] = lr + else: + assert self.after_scheduler is not None and isinstance( + self.after_scheduler, ReduceLROnPlateau + ) + if epoch is None: + self.after_scheduler.step(metrics, None) + else: + self.after_scheduler.step(metrics, epoch - self.warmup_steps) + + def step(self, epoch=None, metrics=None): + if not isinstance(self.after_scheduler, ReduceLROnPlateau): + if self.finished and self.after_scheduler: + if epoch is None: + self.after_scheduler.step(None) + else: + self.after_scheduler.step(epoch - self.warmup_steps) + self._last_lr = self.after_scheduler.get_last_lr() + else: + return super(GradualWarmupScheduler, self).step(epoch) + else: + self.step_ReduceLROnPlateau(metrics, epoch) diff --git a/src/jmp/modules/scheduler/linear_warmup_cos_rlp.py b/src/jmp/modules/scheduler/linear_warmup_cos_rlp.py new file mode 100644 index 0000000..beb449f --- /dev/null +++ b/src/jmp/modules/scheduler/linear_warmup_cos_rlp.py @@ -0,0 +1,77 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from logging import getLogger +from typing import Any + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import ReduceLROnPlateau +from typing_extensions import override + +from .linear_warmup_cosine_annealing import PerParamGroupLinearWarmupCosineAnnealingLR + +log = getLogger(__name__) + + +class PerParamGroupLinearWarmupCosineAnnealingRLPLR( + PerParamGroupLinearWarmupCosineAnnealingLR +): + def __init__( + self, + optimizer: Optimizer, + param_group_settings: list[dict[str, Any]] | dict[str, Any], + rlp_settings: dict, + max_epochs: int, + last_epoch: int = -1, + ) -> None: + self.rlp_start_epoch = max_epochs + self.rlp = ReduceLROnPlateau(optimizer, **rlp_settings) + self.global_step: int | None = None + + super().__init__( + optimizer, + param_group_settings, + last_epoch, + ) + + for settings in self.param_group_settings: + should_restart = settings.get("should_restart", False) + assert ( + not should_restart + ), "If you want to use RLP, set should_restart=False." + + max_epochs_setting = settings.get("max_epochs", None) + if max_epochs_setting is not None: + assert ( + max_epochs_setting == max_epochs + ), f"max_epochs must be {max_epochs}" + else: + settings["max_epochs"] = max_epochs + + def on_new_step(self, global_step: int): + self.global_step = global_step + log.debug(f"global_step: {self.global_step}") + + def is_in_rlp_stage(self, global_step: int | None = None): + if global_step is None: + global_step = self.global_step + if global_step is None: + global_step = self.last_epoch + return global_step >= self.rlp_start_epoch + + def rlp_step(self, metric: float | torch.Tensor): + return self.rlp.step(metric) + + @override + def step(self, metrics=None, epoch=None): + assert metrics is None, f"metrics must be None but got {metrics}" + if self.is_in_rlp_stage(): + return + + return super().step(epoch=epoch) diff --git a/src/jmp/modules/scheduler/linear_warmup_cosine_annealing.py b/src/jmp/modules/scheduler/linear_warmup_cosine_annealing.py new file mode 100644 index 0000000..f2095ac --- /dev/null +++ b/src/jmp/modules/scheduler/linear_warmup_cosine_annealing.py @@ -0,0 +1,175 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import math +import warnings +from typing import Any + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + + +class LinearWarmupCosineAnnealingLR(_LRScheduler): + def __init__( + self, + optimizer: Optimizer, + warmup_epochs: int, + max_epochs: int, + warmup_start_lr: float = 0.0, + eta_min: float = 0.0, + last_epoch: int = -1, + should_restart: bool = True, + ) -> None: + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.warmup_start_lr = warmup_start_lr + self.eta_min = eta_min + self.should_restart = should_restart + + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> list[float]: + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + ) + + if self.last_epoch == 0: + return [self.warmup_start_lr] * len(self.base_lrs) + if self.last_epoch < self.warmup_epochs: + return [ + group["lr"] + + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + if self.last_epoch == self.warmup_epochs: + return self.base_lrs + + if not self.should_restart and self.last_epoch >= self.max_epochs: + return [self.eta_min] * len(self.base_lrs) + + if (self.last_epoch - 1 - self.max_epochs) % ( + 2 * (self.max_epochs - self.warmup_epochs) + ) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) + * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) + / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + + return [ + ( + 1 + + math.cos( + math.pi + * (self.last_epoch - self.warmup_epochs) + / (self.max_epochs - self.warmup_epochs) + ) + ) + / ( + 1 + + math.cos( + math.pi + * (self.last_epoch - self.warmup_epochs - 1) + / (self.max_epochs - self.warmup_epochs) + ) + ) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + +class PerParamGroupLinearWarmupCosineAnnealingLR(_LRScheduler): + def __init__( + self, + optimizer: Optimizer, + param_group_settings: list[dict[str, Any]] | dict[str, Any], + last_epoch: int = -1, + ) -> None: + if isinstance(param_group_settings, dict): + param_group_settings = [param_group_settings] * len(optimizer.param_groups) + + if len(param_group_settings) != len(optimizer.param_groups): + raise ValueError( + "Number of elements in param_group_settings must match the number of parameter groups in the optimizer." + ) + + self.param_group_settings = param_group_settings + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> list[float]: + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + UserWarning, + ) + + new_lrs = [] + for group, settings in zip( + self.optimizer.param_groups, self.param_group_settings + ): + warmup_epochs = settings["warmup_epochs"] + max_epochs = settings["max_epochs"] + warmup_start_lr = settings.get("warmup_start_lr", 0.0) + eta_min = settings.get("eta_min", 0.0) + should_restart = settings.get("should_restart", True) + + if self.last_epoch == 0: + new_lrs.append(warmup_start_lr) + elif self.last_epoch < warmup_epochs: + new_lr = group["lr"] + (group["initial_lr"] - warmup_start_lr) / ( + warmup_epochs - 1 + ) + new_lrs.append(new_lr) + elif self.last_epoch == warmup_epochs: + new_lrs.append(group["initial_lr"]) + elif not should_restart and self.last_epoch >= max_epochs: + new_lrs.append(eta_min) + else: + new_lr = eta_min + 0.5 * (group["initial_lr"] - eta_min) * ( + 1 + + math.cos( + math.pi + * (self.last_epoch - warmup_epochs) + / (max_epochs - warmup_epochs) + ) + ) + new_lrs.append(new_lr) + + return new_lrs + + +def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False): + """Linear warmup for warmup_steps, optionally with cosine annealing or linear decay to 0 at total_steps.""" + assert not (linear and cosine) + + def fn(step): + if step < warmup_steps: + return float(step) / float(max(1, warmup_steps)) + + if not (cosine or linear): + # no decay + return 1.0 + + progress = float(step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + if cosine: + # cosine decay + return 0.5 * (1.0 + math.cos(math.pi * progress)) + + # linear decay + return 1.0 - progress + + return fn diff --git a/src/jmp/modules/transforms/__init__.py b/src/jmp/modules/transforms/__init__.py new file mode 100644 index 0000000..98f9821 --- /dev/null +++ b/src/jmp/modules/transforms/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .atom_ref import atomref_transform as atomref_transform +from .normalize import NormalizationConfig as NormalizationConfig +from .normalize import denormalize_batch as denormalize_batch +from .normalize import normalize as normalize +from .units import update_pyg_data_units as update_pyg_data_units +from .units import update_units_transform as update_units_transform +from .utils import compose as compose diff --git a/src/jmp/modules/transforms/atom_ref.py b/src/jmp/modules/transforms/atom_ref.py new file mode 100644 index 0000000..6cfa665 --- /dev/null +++ b/src/jmp/modules/transforms/atom_ref.py @@ -0,0 +1,26 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import torch +from torch_geometric.data.data import BaseData + + +def atomref_transform( + data: BaseData, + refs: dict[str, torch.Tensor], + keep_raw: bool = False, +): + z: torch.Tensor = data.atomic_numbers + for target, coeffs in refs.items(): + value = getattr(data, target) + if keep_raw: + setattr(data, f"{target}_raw", value.clone()) + value = value - coeffs[z].sum() + setattr(data, target, value) + + return data diff --git a/src/jmp/modules/transforms/normalize.py b/src/jmp/modules/transforms/normalize.py new file mode 100644 index 0000000..03145ab --- /dev/null +++ b/src/jmp/modules/transforms/normalize.py @@ -0,0 +1,83 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections.abc import Mapping +from typing import cast + +import numpy as np +import torch +from jmp.lightning import TypedConfig +from torch_geometric.data.data import BaseData +from typing_extensions import TypeVar + +T = TypeVar("T", float, torch.Tensor, np.ndarray, infer_variance=True) + + +def _process_value(value: T) -> torch.Tensor: + return cast( + torch.Tensor, + torch.tensor(value) if not torch.is_tensor(value) else value, + ) + + +class NormalizationConfig(TypedConfig): + mean: float = 1.0 + std: float = 1.0 + + def normalize(self, value: T) -> T: + return (value - self.mean) / self.std + + def denormalize(self, value: T) -> T: + return (value * self.std) + self.mean + + +def normalize(properties: Mapping[str, NormalizationConfig]): + def _normalize(data: BaseData): + nonlocal properties + + for key, d in properties.items(): + if (value := getattr(data, key, None)) is None: + raise ValueError(f"Property {key} not found in data") + + value = _process_value(value) + value = d.normalize(value) + setattr(data, key, value) + setattr(data, f"{key}_norm_mean", torch.full_like(value, d.mean)) + setattr(data, f"{key}_norm_std", torch.full_like(value, d.std)) + + return data + + return _normalize + + +def denormalize_batch( + batch: BaseData, + additional_tensors: dict[str, torch.Tensor] | None = None, +): + if additional_tensors is None: + additional_tensors = {} + + keys: set[str] = set(batch.keys()) + + # find all keys that have a denorm_mean and denorm_std + norm_keys: set[str] = { + key.replace("_norm_mean", "") for key in keys if key.endswith("_norm_mean") + } & {key.replace("_norm_std", "") for key in keys if key.endswith("_norm_std")} + + for key in norm_keys: + mean = getattr(batch, f"{key}_norm_mean") + std = getattr(batch, f"{key}_norm_std") + value = getattr(batch, key) + + value = (value * std) + mean + setattr(batch, key, value) + + if (additional_value := additional_tensors.pop(key, None)) is not None: + additional_tensors[key] = (additional_value * std) + mean + + return batch, additional_tensors diff --git a/src/jmp/modules/transforms/units.py b/src/jmp/modules/transforms/units.py new file mode 100644 index 0000000..1ec7053 --- /dev/null +++ b/src/jmp/modules/transforms/units.py @@ -0,0 +1,79 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Literal + +from torch_geometric.data.data import BaseData +from typing_extensions import TypeVar + +VALID_UNITS = ("eV", "kcal/mol", "hartree", "bohr", "angstrom") +Unit = Literal["eV", "kcal/mol", "hartree", "bohr", "angstrom"] + + +def _determine_factor(from_: Unit, to: Unit, *, reciprocal: bool = False): + if from_ == to: + return 1.0 + + match (from_, to): + case ("eV", "kcal/mol"): + factor = 23.061 + case ("eV", "hartree"): + factor = 0.0367493 + case ("kcal/mol", "eV"): + factor = 1 / 23.061 + case ("kcal/mol", "hartree"): + factor = 1 / 627.509 + case ("hartree", "eV"): + factor = 1 / 0.0367493 + case ("hartree", "kcal/mol"): + factor = 627.509 + case ("bohr", "angstrom"): + factor = 0.529177 + case ("angstrom", "bohr"): + factor = 1 / 0.529177 + case _: + raise ValueError(f"Cannot convert {from_} to {to}") + + return 1 / factor if reciprocal else factor + + +T = TypeVar("T", bound=BaseData, infer_variance=True) + + +def update_units_transform( + data: T, + attributes: list[str] = ["y", "force"], + *, + from_: Unit, + to: Unit, + reciprocal: bool = False, +) -> T: + factor = _determine_factor(from_, to, reciprocal=reciprocal) + + for attr in attributes: + if (value := getattr(data, attr, None)) is None: + continue + setattr(data, attr, value * factor) + + return data + + +def update_pyg_data_units( + data: BaseData, + attributes: list[str], + *, + from_: Unit, + to: Unit, +): + factor = _determine_factor(from_, to) + for attr in attributes: + if (value := getattr(data, attr, None)) is None: + continue + setattr(data, attr, value * factor) + + return data diff --git a/src/jmp/modules/transforms/utils.py b/src/jmp/modules/transforms/utils.py new file mode 100644 index 0000000..54c4803 --- /dev/null +++ b/src/jmp/modules/transforms/utils.py @@ -0,0 +1,22 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections.abc import Callable + +from torch_geometric.data.data import BaseData + +Transform = Callable[[BaseData], BaseData] + + +def compose(transforms: list[Transform]): + def composed(data: BaseData): + for transform in transforms: + data = transform(data) + return data + + return composed diff --git a/src/jmp/tasks/config.py b/src/jmp/tasks/config.py new file mode 100644 index 0000000..31266a3 --- /dev/null +++ b/src/jmp/tasks/config.py @@ -0,0 +1,159 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from collections.abc import Iterable +from dataclasses import dataclass, field +from logging import getLogger +from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast + +import torch.nn as nn +import torch.optim as optim +from jmp.lightning import Field, TypedConfig + +log = getLogger(__name__) + + +class OutputConfig(TypedConfig): + num_mlps: int = 5 + """Number of MLPs in the output layer.""" + output_init: Literal["HeOrthogonal", "zeros", "grid", "loggrid"] = "HeOrthogonal" + """Initialization method for the output layer.""" + + +class AdamWConfig(TypedConfig): + name: Literal["adamw"] = "adamw" + + lr: float + """Learning rate for the optimizer.""" + + weight_decay: float = 1.0e-2 + """Weight decay (L2 penalty) for the optimizer.""" + + betas: tuple[float, float] = (0.9, 0.999) + """ + Betas for the optimizer: + (beta1, beta2) are the coefficients used for computing running averages of + gradient and its square. + """ + + eps: float = 1e-8 + """Term added to the denominator to improve numerical stability.""" + + amsgrad: bool = False + """Whether to use the AMSGrad variant of this algorithm.""" + + +@dataclass(frozen=True) +class _OptimizerParamGroupConfig: + cls: type[optim.Optimizer] + param_group_kwargs: dict[str, Any] = field(default_factory=lambda: {}) + optimizer_kwargs: dict[str, Any] = field(default_factory=lambda: {}) + + +OptimizerConfig: TypeAlias = Annotated[AdamWConfig, Field(discriminator="name")] + + +class EmbeddingConfig(TypedConfig): + num_elements: int + embedding_size: int + + +def _create_dict_from_config( + config: OptimizerConfig, + params: Iterable[nn.Parameter], + name: str | None = None, +): + # This is a hack to get type hints for the kwargs + # of the module while actually returning a dict. + from torch.optim import AdamW + + AdamWKwargs = AdamW + + if not TYPE_CHECKING: + AdamWKwargs = dict + + if config.lr <= 0: + raise ValueError(f"Learning rate must be positive, got {config.lr}") + + kwargs = cast( + dict, + AdamWKwargs( + params=params, + lr=config.lr, + amsgrad=config.amsgrad, + weight_decay=config.weight_decay, + betas=config.betas, + eps=config.eps, + ), + ) + if name is not None: + kwargs["name"] = name + return _OptimizerParamGroupConfig(AdamW, param_group_kwargs=kwargs) + + +def optimizer_from_config( + param_groups: list[tuple[OptimizerConfig, Iterable[nn.Parameter]]] + | list[tuple[OptimizerConfig, Iterable[nn.Parameter], str | None]], + *, + base: "OptimizerConfig | None" = None, +): + configs = [ + _create_dict_from_config( + param_group[0], + param_group[1], + name=param_group[2] if len(param_group) == 3 else None, + ) + for param_group in param_groups + ] + optimizer_cls_list = [c.cls for c in configs] + assert len(set(optimizer_cls_list)) == 1, "All optimizers must be of the same type" + optimizer_cls = optimizer_cls_list[0] + + optimizer_kwargs_list = [c.optimizer_kwargs for c in configs] + assert ( + len(set(map(str, optimizer_kwargs_list))) == 1 + ), "All optimizers must have the same kwargs" + optimizer_kwargs = optimizer_kwargs_list[0] + + base_kwargs = {} + if base is not None: + base_config = _create_dict_from_config(base, []) + assert ( + base_config.cls == optimizer_cls + ), "Base optimizer must be of the same type" + _ = base_config.param_group_kwargs.pop("params", None) + base_kwargs.update(base_config.param_group_kwargs) + + param_groups_configs = [c.param_group_kwargs for c in configs] + optimizer = optimizer_cls( + params=param_groups_configs, + **optimizer_kwargs, + **base_kwargs, + ) + # detailed log about the optimizer configuration + param_groups_logs: list[str] = [] + for i, c in enumerate(param_groups_configs): + c = copy.deepcopy(c) + params = c.pop("params", None) + n_params = len(params) if params is not None else 0 + total_param_size = sum(p.numel() for p in params) if params is not None else 0 + param_groups_logs.append( + f"Param group {i}:\n" + f" Params: {n_params}\n" + f" Total param size: {total_param_size}\n" + f" Other kwargs: {c}" + ) + param_groups_log = "\n".join(param_groups_logs) + log.critical( + f"Optimizer: {optimizer_cls.__name__}\n" + f"Optimizer kwargs: {optimizer_kwargs}\n" + f"Base kwargs: {base_kwargs}\n" + f"Param groups: {param_groups_log}" + ) + return optimizer diff --git a/src/jmp/tasks/finetune/__init__.py b/src/jmp/tasks/finetune/__init__.py new file mode 100644 index 0000000..0415cbc --- /dev/null +++ b/src/jmp/tasks/finetune/__init__.py @@ -0,0 +1,35 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .base import FinetuneConfigBase, FinetuneModelBase +from .matbench import MatbenchConfig, MatbenchModel +from .md22 import MD22Config, MD22Model +from .pdbbind import PDBBindConfig, PDBBindModel +from .qm9 import QM9Config, QM9Model +from .qmof import QMOFConfig, QMOFModel +from .rmd17 import RMD17Config, RMD17Model +from .spice import SPICEConfig, SPICEModel + +__all__ = [ + "FinetuneConfigBase", + "FinetuneModelBase", + "MatbenchConfig", + "MatbenchModel", + "MD22Config", + "MD22Model", + "PDBBindConfig", + "PDBBindModel", + "QM9Config", + "QM9Model", + "QMOFConfig", + "QMOFModel", + "RMD17Config", + "RMD17Model", + "SPICEConfig", + "SPICEModel", +] diff --git a/src/jmp/tasks/finetune/base.py b/src/jmp/tasks/finetune/base.py new file mode 100644 index 0000000..7593fda --- /dev/null +++ b/src/jmp/tasks/finetune/base.py @@ -0,0 +1,1740 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import fnmatch +import itertools +import math +from abc import abstractmethod +from collections.abc import Iterable, Mapping +from functools import partial +from logging import getLogger +from pathlib import Path +from typing import Annotated, Any, Generic, Literal, TypeAlias, assert_never, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from jmp.lightning import Base, BaseConfig, Field, LightningModuleBase, TypedConfig +from jmp.lightning.data.balanced_batch_sampler import ( + BalancedBatchSampler, + DatasetWithSizes, +) +from jmp.lightning.util.typed import TypedModuleDict +from lightning.pytorch.callbacks import ModelCheckpoint +from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from torch_geometric.data.batch import Batch +from torch_geometric.data.data import BaseData +from torch_scatter import scatter +from typing_extensions import TypedDict, TypeVar, override + +from ...datasets.finetune.base import LmdbDataset +from ...datasets.finetune_pdbbind import PDBBindConfig, PDBBindDataset +from ...models.gemnet.backbone import GemNetOCBackbone, GOCBackboneOutput +from ...models.gemnet.config import BackboneConfig +from ...models.gemnet.layers.base_layers import ScaledSiLU +from ...modules import transforms as T +from ...modules.dataset import dataset_transform as DT +from ...modules.dataset.common import CommonDatasetConfig, wrap_common_dataset +from ...modules.early_stopping import EarlyStoppingWithMinLR +from ...modules.ema import EMAConfig +from ...modules.scheduler.gradual_warmup_lr import GradualWarmupScheduler +from ...modules.scheduler.linear_warmup_cos_rlp import ( + PerParamGroupLinearWarmupCosineAnnealingRLPLR, +) +from ...modules.transforms.normalize import NormalizationConfig +from ...utils.goc_graph import ( + Cutoffs, + Graph, + MaxNeighbors, + generate_graph, + subselect_graph, + tag_mask, +) +from ...utils.state_dict import load_state_dict +from ..config import ( + EmbeddingConfig, + OptimizerConfig, + OutputConfig, + optimizer_from_config, +) +from .metrics import FinetuneMetrics, MetricPair, MetricsConfig + +log = getLogger(__name__) + +DatasetType: TypeAlias = LmdbDataset + + +class RLPWarmupConfig(TypedConfig): + steps: int + """Number of steps for the warmup""" + + start_lr_factor: float + """The factor to multiply the initial learning rate by at the start of the warmup""" + + +class RLPConfig(TypedConfig): + name: Literal["rlp"] = "rlp" + + monitor: str | None = None + mode: str | None = None + patience: int = 10 + factor: float = 0.1 + min_lr: float = 0.0 + eps: float = 1.0e-8 + cooldown: int = 0 + threshold: float = 1.0e-4 + threshold_mode: str = "rel" + interval: str = "epoch" + frequency: int = 1 + warmup: RLPWarmupConfig | None = None + + def _to_linear_warmup_cos_rlp_dict(self): + """ + Params for PerParamGroupLinearWarmupCosineAnnealingRLPLR's RLP + mode="min", + factor=0.1, + patience=10, + threshold=1e-4, + threshold_mode="rel", + cooldown=0, + min_lr=0, + eps=1e-8, + verbose=False, + """ + return { + "mode": self.mode, + "factor": self.factor, + "patience": self.patience, + "threshold": self.threshold, + "threshold_mode": self.threshold_mode, + "cooldown": self.cooldown, + "min_lr": self.min_lr, + "eps": self.eps, + "verbose": True, + } + + +class WarmupCosRLPConfig(TypedConfig): + name: Literal["warmup_cos_rlp"] = "warmup_cos_rlp" + + warmup_steps: int | None = None + warmup_epochs: int | None = None + max_steps: int | None = None + max_epochs: int | None = None + warmup_start_lr_factor: float = 0.0 + min_lr_factor: float = 1.0e-2 + last_step: int = -1 + should_restart: bool = False + + rlp: RLPConfig + + @override + def __post_init__(self): + super().__post_init__() + + assert self.rlp.warmup is None, "RLP warmup is not supported" + + +LRSchedulerConfig: TypeAlias = Annotated[ + RLPConfig | WarmupCosRLPConfig, Field(discriminator="name") +] + + +class FreezeConfig(TypedConfig): + backbone: bool = False + """Should the backbone be frozen?""" + embedding: bool = False + """Should the embedding layer be frozen?""" + + backbone_bases: bool = False + """Should the basis functions in the backbone be frozen?""" + backbone_interaction_layers: list[int] | None = None + """Which interaction layers, if any, in the backbone should be frozen?""" + backbone_output_layers: list[int] | None = None + """Which output layers, if any, in the backbone should be frozen?""" + + parameter_patterns: list[str] = [] + """List of parameter patterns to freeze""" + + +class ParamSpecificOptimizerConfig(TypedConfig): + name: str | None = None + """The name of the parameter group for this config""" + + paremeter_patterns: list[str] = [] + """List of parameter patterns to match for this config""" + + optimizer: OptimizerConfig | None = None + """ + The optimizer config for this parameter group. + If None, the default optimizer will be used. + """ + + lr_scheduler: LRSchedulerConfig | None = None + """ + The learning rate scheduler config for this parameter group. + If None, the default learning rate scheduler will be used. + """ + + +class CheckpointLoadConfig(TypedConfig): + ignored_key_patterns: list[str] = [] + """Patterns to ignore when loading the checkpoint""" + + ignored_missing_keys: list[str] = [] + """Keys to ignore if they are missing in the checkpoint""" + + ignored_unexpected_keys: list[str] = [] + """Keys to ignore if they are unexpected in the checkpoint""" + + reset_embeddings: bool = False + """ + If true, it will reset the embeddings to the initial state + after loading the checkpoint + """ + + +class CheckpointBestConfig(TypedConfig): + monitor: str | None = None + """ + The metric to monitor for checkpointing. + If None, the primary metric will be used. + """ + mode: Literal["min", "max"] | None = None + """ + The mode for the metric to monitor for checkpointing. + If None, the primary metric mode will be used. + """ + + +class EarlyStoppingConfig(TypedConfig): + monitor: str | None = None + """ + The metric to monitor for early stopping. + If None, the primary metric will be used. + """ + mode: Literal["min", "max"] | None = None + """ + The mode for the metric to monitor for early stopping. + If None, the primary metric mode will be used. + """ + + patience: int + """ + Number of epochs with no improvement after which training will be stopped. + """ + + min_delta: float = 1.0e-8 + """ + Minimum change in the monitored quantity to qualify as an improvement. + """ + min_lr: float | None = None + """ + Minimum learning rate. If the learning rate of the model is less than this value, + the training will be stopped. + """ + strict: bool = True + """ + Whether to enforce that the monitored quantity must improve by at least `min_delta` + to qualify as an improvement. + """ + + +class BinaryClassificationTargetConfig(TypedConfig): + name: str + """The name of the target""" + num_classes: int + """The number of classes for the target""" + + pos_weight: float | None = None + """The positive weight for the target""" + + @override + def __post_init__(self): + super().__post_init__() + + if self.num_classes != 2: + raise ValueError( + f"Binary classification target {self.name} has {self.num_classes} classes" + ) + + +class MulticlassClassificationTargetConfig(TypedConfig): + name: str + """The name of the target""" + num_classes: int + """The number of classes for the target""" + + class_weights: list[float] | None = None + """The class weights for the target""" + dropout: float | None = None + """The dropout probability to use before the output layer""" + + +class PrimaryMetricConfig(TypedConfig): + name: str + """The name of the primary metric""" + mode: Literal["min", "max"] + """ + The mode of the primary metric: + - "min" for metrics that should be minimized (e.g., loss) + - "max" for metrics that should be maximized (e.g., accuracy) + """ + + +class TestConfig(TypedConfig): + save_checkpoint_base_dir: Path | None = None + """Where to save the checkpoint information for this run (or None to disable)""" + + save_results_base_dir: Path | None = None + """Where to save the results for this run (or None to disable)""" + + +class FinetuneLmdbDatasetConfig(CommonDatasetConfig): + name: Literal["lmdb"] = "lmdb" + + src: Path + """Path to the LMDB file or directory containing LMDB files.""" + + metadata_path: Path | None = None + """Path to the metadata npz file containing the number of atoms in each structure.""" + + def __post_init__(self): + super().__post_init__() + + # If metadata_path is not provided, assume it is src/metadata.npz + if self.metadata_path is None: + self.metadata_path = self.src / "metadata.npz" + + def create_dataset(self): + return LmdbDataset(src=self.src, metadata_path=self.metadata_path) + + +class FinetunePDBBindDatasetConfig(PDBBindConfig, CommonDatasetConfig): + name: Literal["pdbbind"] = "pdbbind" + + def create_dataset(self): + return PDBBindDataset(task=self.task, split=self.split) + + +FinetuneDatasetConfig: TypeAlias = Annotated[ + FinetuneLmdbDatasetConfig | FinetunePDBBindDatasetConfig, + Field(discriminator="name"), +] + + +class FinetuneConfigBase(BaseConfig): + train_dataset: FinetuneDatasetConfig | None = None + """Configuration for the train dataset""" + val_dataset: FinetuneDatasetConfig | None = None + """Configuration for the val dataset""" + test_dataset: FinetuneDatasetConfig | None = None + """Configuration for the test dataset""" + + optimizer: OptimizerConfig + """Optimizer to use.""" + lr_scheduler: LRSchedulerConfig | None = None + """Learning rate scheduler configuration. If None, no learning rate scheduler is used.""" + + activation: Literal[ + "scaled_silu", + "scaled_swish", + "silu", + "swish", + ] = "scaled_silu" + """Activation function to use.""" + + embedding: EmbeddingConfig = EmbeddingConfig( + num_elements=BackboneConfig.base().num_elements, + embedding_size=BackboneConfig.base().emb_size_atom, + ) + """Configuration for the embedding layer.""" + backbone: BackboneConfig + """Configuration for the backbone.""" + output: OutputConfig = OutputConfig(num_mlps=5) + """Configuration for the output head.""" + + batch_size: int + """Batch size to use.""" + eval_batch_size: int | None = None + """Batch size to use for evaluation. If None, use the same as batch_size.""" + num_workers: int = 8 + """Number of workers to use for data loading.""" + pin_memory: bool = True + """Whether to use pin memory for data loading.""" + + @property + def activation_cls(self): + match self.activation: + case "scaled_silu" | "scaled_swish": + return ScaledSiLU + case "silu" | "swish": + return nn.SiLU + case None: + return nn.Identity + case _: + raise NotImplementedError( + f"Activation {self.activation} is not implemented" + ) + + primary_metric: PrimaryMetricConfig + """Primary metric to use for early stopping and checkpointing""" + early_stopping: EarlyStoppingConfig | None = None + """Configuration for early stopping""" + ckpt_best: CheckpointBestConfig | None = CheckpointBestConfig() + """Configuration for saving the best checkpoint""" + + test: TestConfig | None = None + """Configuration for test stage""" + + graph_scalar_targets: list[str] = [] + """List of graph scalar targets (e.g., energy)""" + graph_classification_targets: list[ + BinaryClassificationTargetConfig | MulticlassClassificationTargetConfig + ] = [] + """List of graph classification targets (e.g., is_metal)""" + node_vector_targets: list[str] = [] + """List of node vector targets (e.g., force)""" + + @property + def regression_targets(self): + """List of all regression targets, i.e., graph scalar and node vector targets""" + return self.node_vector_targets + self.graph_scalar_targets + + @property + def all_targets(self): + """List of all targets, i.e., graph scalar, graph classification, and node vector targets""" + return ( + self.node_vector_targets + + self.graph_scalar_targets + + [target.name for target in self.graph_classification_targets] + ) + + graph_scalar_loss_coefficient_default: float = 1.0 + """Default loss coefficient for graph scalar targets, if not specified in `graph_scalar_loss_coefficients`""" + graph_classification_loss_coefficient_default: float = 1.0 + """Default loss coefficient for graph classification targets, if not specified in `graph_classification_loss_coefficients`""" + node_vector_loss_coefficient_default: float = 1.0 + """Default loss coefficient for node vector targets, if not specified in `node_vector_loss_coefficients`""" + graph_scalar_loss_coefficients: dict[str, float] = {} + """Loss coefficients for graph scalar targets""" + graph_classification_loss_coefficients: dict[str, float] = {} + """Loss coefficients for graph classification targets""" + node_vector_loss_coefficients: dict[str, float] = {} + """Loss coefficients for node vector targets""" + + graph_scalar_reduction_default: Literal["sum", "mean", "max"] = "sum" + """Default reduction method, if not specified in `graph_scalar_reduction`, for computing graph scalar targets from each node's scalar prediction""" + graph_classification_reduction_default: Literal["sum", "mean", "max"] = "sum" + """Default reduction method, if method fornot specified in `graph_classification_reduction`, graph classification targets from each node's classification prediction""" + node_vector_reduction_default: Literal["sum", "mean", "max"] = "sum" + """Default reduction method, if not specified in `node_vector_reduction`, for computing node vector targets from each edge's vector prediction""" + + graph_scalar_reduction: dict[str, Literal["sum", "mean", "max"]] = {} + """Reduction methods for computing graph scalar targets from each node's scalar prediction""" + graph_classification_reduction: dict[str, Literal["sum", "mean", "max"]] = {} + """Reduction methods for computing graph classification targets from each node's classification prediction""" + node_vector_reduction: dict[str, Literal["sum", "mean", "max"]] = {} + """Reduction methods for computing node vector targets from each edge's vector prediction""" + + normalization: dict[str, NormalizationConfig] = {} + """Normalization parameters for each target""" + + parameter_specific_optimizers: list[ParamSpecificOptimizerConfig] | None = None + """Configuration for parameter-specific optimizers""" + + use_balanced_batch_sampler: bool = False + """ + Whether to use balanced batch sampler. + + This balances the batches across all distributed nodes (i.e., GPUs, TPUs, nodes, etc.) + to ensure that each batch has an equal number of **atoms** across all nodes. + """ + + freeze: FreezeConfig = FreezeConfig() + """Configuration for freezing parameters""" + + ckpt_load: CheckpointLoadConfig = CheckpointLoadConfig() + """Configuration for behavior when loading checkpoints""" + + shuffle_val: bool = False + """Whether to shuffle the validation set""" + shuffle_test: bool = False + """Whether to shuffle the test set""" + + metrics: MetricsConfig = MetricsConfig() + """Configuration for metrics""" + + ema: EMAConfig | None = None + """Configuration for exponential moving average""" + + @override + def __post_init__(self): + super().__post_init__() + + if self.use_balanced_batch_sampler: + assert not self.trainer.use_distributed_sampler, "config.trainer.use_distributed_sampler must be False when using balanced batch sampler" + + +TConfig = TypeVar("TConfig", bound=FinetuneConfigBase) + + +class OutputHeadInput(TypedDict): + data: BaseData + backbone_output: GOCBackboneOutput + + +class GraphScalarOutputHead(Base[TConfig], nn.Module, Generic[TConfig]): + @override + def __init__( + self, + config: TConfig, + reduction: str | None = None, + ): + super().__init__(config) + + if reduction is None: + reduction = self.config.graph_scalar_reduction_default + + self.out_mlp = self.mlp( + ([self.config.backbone.emb_size_atom] * self.config.output.num_mlps) + + [self.config.backbone.num_targets], + activation=self.config.activation_cls, + ) + self.reduction = reduction + + @override + def forward( + self, + input: OutputHeadInput, + *, + scale: torch.Tensor | None = None, + shift: torch.Tensor | None = None, + ) -> torch.Tensor: + data = input["data"] + backbone_output = input["backbone_output"] + + n_molecules = int(torch.max(data.batch).item() + 1) + + output = self.out_mlp(backbone_output["energy"]) # (n_atoms, 1) + if scale is not None: + output = output * scale + if shift is not None: + output = output + shift + + output = scatter( + output, + data.batch, + dim=0, + dim_size=n_molecules, + reduce=self.reduction, + ) # (bsz, 1) + output = rearrange(output, "b 1 -> b") + return output + + +class GraphBinaryClassificationOutputHead(Base[TConfig], nn.Module, Generic[TConfig]): + @override + def __init__( + self, + config: TConfig, + classification_config: BinaryClassificationTargetConfig, + reduction: str | None = None, + ): + super().__init__(config) + + assert ( + classification_config.num_classes == 2 + ), "Only binary classification supported" + + if reduction is None: + reduction = self.config.graph_scalar_reduction_default + + self.out_mlp = self.mlp( + ([self.config.backbone.emb_size_atom] * self.config.output.num_mlps) + [1], + activation=self.config.activation_cls, + ) + self.classification_config = classification_config + self.reduction = reduction + + @override + def forward(self, input: OutputHeadInput) -> torch.Tensor: + data = input["data"] + backbone_output = input["backbone_output"] + + n_molecules = int(torch.max(data.batch).item() + 1) + + output = self.out_mlp(backbone_output["energy"]) # (n, num_classes) + output = scatter( + output, + data.batch, + dim=0, + dim_size=n_molecules, + reduce=self.reduction, + ) # (bsz, num_classes) + output = rearrange(output, "b 1 -> b") + return output + + +class GraphMulticlassClassificationOutputHead( + Base[TConfig], nn.Module, Generic[TConfig] +): + @override + def __init__( + self, + config: TConfig, + classification_config: MulticlassClassificationTargetConfig, + reduction: str | None = None, + ): + super().__init__(config) + + if reduction is None: + reduction = self.config.graph_scalar_reduction_default + + self.out_mlp = self.mlp( + ([self.config.backbone.emb_size_atom] * self.config.output.num_mlps) + + [classification_config.num_classes], + activation=self.config.activation_cls, + ) + self.classification_config = classification_config + self.reduction = reduction + + self.dropout = None + if classification_config.dropout: + self.dropout = nn.Dropout(classification_config.dropout) + + @override + def forward(self, input: OutputHeadInput) -> torch.Tensor: + data = input["data"] + n_molecules = int(torch.max(data.batch).item() + 1) + + x = input["backbone_output"]["energy"] + if self.dropout is not None: + x = self.dropout(x) + + x = self.out_mlp(x) # (n, num_classes) + x = scatter( + x, + data.batch, + dim=0, + dim_size=n_molecules, + reduce=self.reduction, + ) # (bsz, num_classes) + return x + + +class NodeVectorOutputHead(Base[TConfig], nn.Module, Generic[TConfig]): + @override + def __init__( + self, + config: TConfig, + reduction: str | None = None, + ): + super().__init__(config) + + if reduction is None: + reduction = self.config.graph_scalar_reduction_default + + self.out_mlp = self.mlp( + ([self.config.backbone.emb_size_edge] * self.config.output.num_mlps) + + [self.config.backbone.num_targets], + activation=self.config.activation_cls, + ) + self.reduction = reduction + + @override + def forward(self, input: OutputHeadInput) -> torch.Tensor: + data = input["data"] + backbone_output = input["backbone_output"] + + n_atoms = data.atomic_numbers.shape[0] + + output = self.out_mlp(backbone_output["forces"]) + output = output * backbone_output["V_st"] # (n_edges, 3) + output = scatter( + output, + backbone_output["idx_t"], + dim=0, + dim_size=n_atoms, + reduce=self.reduction, + ) + return output + + +class FinetuneModelBase(LightningModuleBase[TConfig], Generic[TConfig]): + @abstractmethod + def metric_prefix(self) -> str: ... + + @override + def on_test_end(self): + super().on_test_end() + + match self.config.test: + case TestConfig(save_checkpoint_base_dir=Path() as base): + # The save dir for this run should be base/{metric_prefix()}/{config.name}-{config.id} + base = base / self.metric_prefix() + base.mkdir(parents=True, exist_ok=True) + save_dir = base / f"{self.config.name}-{self.config.id}" + if save_dir.exists(): + i = 0 + while ( + save_dir := base / f"{self.config.name}-{self.config.id}-{i}" + ).exists(): + i += 1 + save_dir.mkdir(parents=True, exist_ok=True) + + # Get ckpt path from config + ckpt_path = self.config.meta.get("ckpt_path") + if ckpt_path is None: + raise ValueError( + f"Checkpoint path not found in meta: {self.config.meta=}" + ) + ckpt_path = Path(ckpt_path) + if not ckpt_path.exists(): + raise ValueError(f"Checkpoint path does not exist: {ckpt_path=}") + + # Create a symlink to the checkpoint + symlink_path = base / f"pretrained-{ckpt_path.name}" + if symlink_path.exists(): + raise ValueError(f"Symlink path already exists: {symlink_path=}") + symlink_path.symlink_to(ckpt_path) + + # Also create an ckptpath.txt file that contains the original ckpt path + _ = (base / "ckptpath.txt").write_text( + str(ckpt_path.resolve().absolute()) + ) + + log.critical(f"Saving checkpoint information to {save_dir}") + case _: + pass + + def primary_metric(self, split: Literal["train", "val", "test"] | None = "val"): + config = self.config.primary_metric + metric = f"{self.metric_prefix()}/{config.name}" + if split is not None: + metric = f"{split}/{metric}" + return metric, config.mode + + def _set_rlp_config_monitors(self): + match self.config.lr_scheduler: + case RLPConfig(monitor=None) as rlp_config: + rlp_config.monitor, rlp_config.mode = self.primary_metric() + case WarmupCosRLPConfig(rlp=RLPConfig(monitor=None) as rlp_config): + rlp_config.monitor, rlp_config.mode = self.primary_metric() + case _: + pass + + def validate_config(self, config: TConfig): + assert config.activation.lower() == config.backbone.activation.lower() + + assert config.embedding.num_elements == config.backbone.num_elements + assert config.embedding.embedding_size == config.backbone.emb_size_atom + + assert config.all_targets, f"No targets specified, {config.all_targets=}" + for a, b in itertools.combinations( + [ + config.graph_scalar_targets, + [target.name for target in config.graph_classification_targets], + config.node_vector_targets, + ], + 2, + ): + assert ( + set(a) & set(b) == set() + ), f"Targets must be disjoint, but they are not: {a} and {b}" + # config.targets = config.graph_scalar_targets + config.node_vector_targets + + if config.graph_scalar_loss_coefficients: + assert set(config.graph_scalar_loss_coefficients.keys()).issubset( + set(config.graph_scalar_targets) + ), ( + f"Loss coefficients must correspond to graph scalar targets, but they " + f"do not: {config.graph_scalar_loss_coefficients.keys()=} vs " + f"{config.graph_scalar_targets=}" + ) + + if config.node_vector_loss_coefficients: + assert set(config.node_vector_loss_coefficients.keys()).issubset( + set(config.node_vector_targets) + ), ( + f"Loss coefficients must correspond to node vector targets, but they " + f"do not: {config.node_vector_loss_coefficients.keys()=} vs " + f"{config.node_vector_targets=}" + ) + + def _construct_backbone(self): + log.critical("Using regular backbone") + + backbone = GemNetOCBackbone(self.config.backbone, **dict(self.config.backbone)) + + return backbone + + def metrics_provider( + self, + prop: str, + batch: BaseData, + preds: dict[str, torch.Tensor], + ) -> MetricPair | None: + if (pred := preds.get(prop)) is None or ( + target := getattr(batch, prop, None) + ) is None: + return None + + if ( + self.config.normalization + and (norm := self.config.normalization.get(prop)) is not None + ): + # Denormalize the predictions and targets + pred = pred * norm.std + norm.mean + target = target * norm.std + norm.mean + + return MetricPair(predicted=pred, ground_truth=target) + + @override + def __init__(self, hparams: TConfig): + self.validate_config(hparams) + super().__init__(hparams) + + # Set up callbacks + if (ema := self.config.ema) is not None: + self.register_callback(lambda: ema.construct_callback()) + + self._set_rlp_config_monitors() + + self.embedding = nn.Embedding( + num_embeddings=self.config.embedding.num_elements, + embedding_dim=self.config.embedding.embedding_size, + ) + + self.backbone = self._construct_backbone() + self.register_shared_parameters(self.backbone.shared_parameters) + + self.construct_output_heads() + + self.train_metrics = FinetuneMetrics( + self.config.metrics, + self.metrics_provider, + self.config.graph_scalar_targets, + self.config.graph_classification_targets, + self.config.node_vector_targets, + ) + self.val_metrics = FinetuneMetrics( + self.config.metrics, + self.metrics_provider, + self.config.graph_scalar_targets, + self.config.graph_classification_targets, + self.config.node_vector_targets, + ) + self.test_metrics = FinetuneMetrics( + self.config.metrics, + self.metrics_provider, + self.config.graph_scalar_targets, + self.config.graph_classification_targets, + self.config.node_vector_targets, + ) + + # Sanity check: ensure all named_parameters have requires_grad=True, + # otherwise add them to ignored_parameters. + self.ignored_parameters = set[nn.Parameter]() + for name, param in self.named_parameters(): + if param.requires_grad: + continue + self.ignored_parameters.add(param) + log.info(f"Adding {name} to ignored_parameters") + + self.process_freezing() + + if (ckpt_best := self.config.ckpt_best) is not None: + if (monitor := ckpt_best.monitor) is None: + monitor, mode = self.primary_metric() + else: + if (mode := ckpt_best.mode) is None: + mode = "min" + + self.register_callback(lambda: ModelCheckpoint(monitor=monitor, mode=mode)) + + if (early_stopping := self.config.early_stopping) is not None: + if (monitor := early_stopping.monitor) is None: + monitor, mode = self.primary_metric() + else: + if (mode := early_stopping.mode) is None: + mode = "min" + + self.register_callback( + lambda: EarlyStoppingWithMinLR( + monitor=monitor, + mode=mode, + patience=early_stopping.patience, + min_delta=early_stopping.min_delta, + min_lr=early_stopping.min_lr, + strict=early_stopping.strict, + ) + ) + + for cls_target in self.config.graph_classification_targets: + match cls_target: + case MulticlassClassificationTargetConfig( + class_weights=class_weights + ) if class_weights: + self.register_buffer( + f"{cls_target.name}_class_weights", + torch.tensor(class_weights, dtype=torch.float), + persistent=False, + ) + case _: + pass + + def freeze_parameters(self, parameters: Iterable[nn.Parameter], *, name: str): + n_params = 0 + for param in parameters: + if param in self.ignored_parameters: + continue + + param.requires_grad = False + n_params += param.numel() + log.critical(f"Freezing {n_params} parameters in {name}") + + def named_parameters_matching_patterns(self, patterns: list[str]): + for name, param in self.named_parameters(): + if param in self.ignored_parameters: + continue + if ( + matching_pattern := next( + (pattern for pattern in patterns if fnmatch.fnmatch(name, pattern)), + None, + ) + ) is None: + continue + + yield name, param, matching_pattern + + def process_freezing(self): + if self.config.freeze.backbone: + self.freeze_parameters(self.backbone.parameters(), name="backbone") + + if self.config.freeze.embedding: + self.freeze_parameters(self.embedding.parameters(), name="embedding") + + if self.config.freeze.backbone_interaction_layers: + for layer_idx in self.config.freeze.backbone_interaction_layers: + self.freeze_parameters( + self.backbone.int_blocks[layer_idx].parameters(), + name=f"backbone.int_blocks[{layer_idx}]", + ) + + if self.config.freeze.backbone_output_layers: + for layer_idx in self.config.freeze.backbone_output_layers: + self.freeze_parameters( + self.backbone.out_blocks[layer_idx].parameters(), + name=f"backbone.out_blocks[{layer_idx}]", + ) + + if self.config.freeze.backbone_bases: + self.freeze_parameters( + self.backbone.bases.parameters(), name="backbone.bases" + ) + + if self.config.freeze.parameter_patterns: + for ( + name, + param, + matching_pattern, + ) in self.named_parameters_matching_patterns( + self.config.freeze.parameter_patterns + ): + param.requires_grad = False + log.info(f"Freezing {name} (pattern: {matching_pattern})") + + all_parameters = [ + param for param in self.parameters() if param not in self.ignored_parameters + ] + num_frozen = sum( + param.numel() for param in all_parameters if not param.requires_grad + ) + num_train = sum( + param.numel() for param in all_parameters if param.requires_grad + ) + num_total = sum(param.numel() for param in all_parameters) + percent_frozen = num_frozen / num_total * 100 + log.critical( + f"Freezing {num_frozen:,} parameters ({percent_frozen:.2f}%) out of " + f"{num_total:,} total parameters ({num_train:,} trainable)" + ) + + def construct_graph_scalar_output_head(self, target: str) -> nn.Module: + return GraphScalarOutputHead( + self.config, + reduction=self.config.graph_scalar_reduction.get( + target, self.config.graph_scalar_reduction_default + ), + ) + + def construct_graph_classification_output_head( + self, + target: BinaryClassificationTargetConfig | MulticlassClassificationTargetConfig, + ) -> nn.Module: + match target: + case BinaryClassificationTargetConfig(): + return GraphBinaryClassificationOutputHead( + self.config, + target, + reduction=self.config.graph_classification_reduction.get( + target.name, self.config.graph_classification_reduction_default + ), + ) + case MulticlassClassificationTargetConfig(): + return GraphMulticlassClassificationOutputHead( + self.config, + target, + reduction=self.config.graph_classification_reduction.get( + target.name, self.config.graph_classification_reduction_default + ), + ) + case _: + raise ValueError(f"Invalid target: {target}") + + def construct_node_vector_output_head(self, target: str) -> nn.Module: + return NodeVectorOutputHead( + self.config, + reduction=self.config.node_vector_reduction.get( + target, self.config.node_vector_reduction_default + ), + ) + + def construct_output_heads(self): + self.graph_outputs = TypedModuleDict( + { + target: self.construct_graph_scalar_output_head(target) + for target in self.config.graph_scalar_targets + }, + key_prefix="ft_mlp_", + ) + self.graph_classification_outputs = TypedModuleDict( + { + target.name: self.construct_graph_classification_output_head(target) + for target in self.config.graph_classification_targets + }, + key_prefix="ft_mlp_", + ) + self.node_outputs = TypedModuleDict( + { + target: self.construct_node_vector_output_head(target) + for target in self.config.node_vector_targets + }, + key_prefix="ft_mlp_", + ) + + def load_backbone_state_dict( + self, + *, + backbone: Mapping[str, Any], + embedding: Mapping[str, Any], + strict: bool = True, + ): + ignored_key_patterns = self.config.ckpt_load.ignored_key_patterns + # If we're dumping the backbone's force out heads, then we need to ignore + # the unexpected keys for the force out MLPs and force out heads. + if ( + not self.config.backbone.regress_forces + or not self.config.backbone.direct_forces + ): + ignored_key_patterns.append("out_mlp_F.*") + for block_idx in range(self.config.backbone.num_blocks + 1): + ignored_key_patterns.append(f"out_blocks.{block_idx}.scale_rbf_F.*") + ignored_key_patterns.append(f"out_blocks.{block_idx}.dense_rbf_F.*") + ignored_key_patterns.append(f"out_blocks.{block_idx}.seq_forces.*") + + load_state_dict( + self.backbone, + backbone, + strict=strict, + ignored_key_patterns=ignored_key_patterns, + ignored_missing_keys=self.config.ckpt_load.ignored_missing_keys, + ignored_unexpected_keys=self.config.ckpt_load.ignored_unexpected_keys, + ) + if not self.config.ckpt_load.reset_embeddings: + load_state_dict(self.embedding, embedding, strict=strict) + log.critical("Loaded backbone state dict (backbone and embedding).") + + @override + def forward(self, data: BaseData): + atomic_numbers = data.atomic_numbers - 1 + h = self.embedding(atomic_numbers) # (N, d_model) + out = cast(GOCBackboneOutput, self.backbone(data, h=h)) + + output_head_input: OutputHeadInput = { + "backbone_output": out, + "data": data, + } + + preds = { + **{ + target: module(output_head_input) + for target, module in self.graph_outputs.items() + }, + **{ + target: module(output_head_input) + for target, module in self.graph_classification_outputs.items() + }, + **{ + target: module(output_head_input) + for target, module in self.node_outputs.items() + }, + } + return preds + + def compute_losses(self, batch: BaseData, preds: dict[str, torch.Tensor]): + losses: list[torch.Tensor] = [] + + for target in self.config.graph_scalar_targets: + loss = F.l1_loss(preds[target], batch[target]) + self.log(f"{target}_loss", loss) + + coef = self.config.graph_scalar_loss_coefficients.get( + target, self.config.graph_scalar_loss_coefficient_default + ) + loss = coef * loss + self.log(f"{target}_loss_scaled", loss) + + losses.append(loss) + + for target in self.config.graph_classification_targets: + match target: + case BinaryClassificationTargetConfig(): + y_input = preds[target.name] + y_target = batch[target.name].float() + pos_weight = None + if target.pos_weight is not None: + pos_weight = y_input.new_tensor(target.pos_weight) + loss = F.binary_cross_entropy_with_logits( + y_input, y_target, reduction="sum", pos_weight=pos_weight + ) + case MulticlassClassificationTargetConfig(): + weight = None + if target.class_weights: + weight = self.get_buffer(f"{target.name}_class_weights") + + loss = F.cross_entropy( + preds[target.name], + batch[target.name].long(), + weight=weight, + reduction="sum", + ) + case _: + raise ValueError(f"Unknown target type: {target}") + self.log(f"{target.name}_loss", loss) + + coef = self.config.graph_classification_loss_coefficients.get( + target.name, self.config.graph_classification_loss_coefficient_default + ) + loss = coef * loss + self.log(f"{target.name}_loss_scaled", loss) + + losses.append(loss) + + for target in self.config.node_vector_targets: + assert preds[target].shape[-1] == 3 + loss = F.pairwise_distance(preds[target], batch[target], p=2.0).mean() + self.log(f"{target}_loss", loss) + + coef = self.config.node_vector_loss_coefficients.get( + target, self.config.node_vector_loss_coefficient_default + ) + loss = coef * loss + self.log(f"{target}_loss_scaled", loss) + + losses.append(loss) + + loss = sum(losses) + self.log("loss", loss) + + return loss + + def _rlp_metric(self, config: RLPConfig): + monitor = config.monitor + assert monitor is not None, "RLP monitor must be specified." + + metric_prefix = f"val/{self.metric_prefix()}/" + assert monitor.startswith( + metric_prefix + ), f"RLP {monitor=} must start with {metric_prefix}" + monitor = monitor[len(metric_prefix) :] + + if ( + monitor.endswith("_mae") + and (mae_metric := self.val_metrics.maes.get(monitor[: -len("_mae")])) + is not None + ): + return mae_metric + + if ( + monitor.endswith("_balanced_accuracy") + and ( + cls_metric := self.val_metrics.cls_metrics.get( + monitor[: -len("_balanced_accuracy")] + ) + ) + is not None + ): + return cls_metric + + avail_mae_metrics = list(self.val_metrics.maes.keys()) + avail_cls_metrics = list(self.val_metrics.cls_metrics.keys()) + raise ValueError( + f"RLP monitor {monitor} not found in metrics. " + f"Available MAE metrics: {avail_mae_metrics}. " + f"Available classification metrics: {avail_cls_metrics}" + ) + + def _cos_rlp_schedulers(self): + if (lr_schedulers := self.lr_schedulers()) is None: + log.warning("No LR scheduler found.") + return + + if not isinstance(lr_schedulers, list): + lr_schedulers = [lr_schedulers] + + for scheduler in lr_schedulers: + if isinstance(scheduler, PerParamGroupLinearWarmupCosineAnnealingRLPLR): + yield scheduler + + def _on_validation_epoch_end_cos_rlp(self, config: WarmupCosRLPConfig): + rlp_monitor = self._rlp_metric(config.rlp) + log.info(f"LR scheduler metrics: {rlp_monitor}") + + metric_value: torch.Tensor | None = None + for scheduler in self._cos_rlp_schedulers(): + if scheduler.is_in_rlp_stage(self.global_step): + if metric_value is None: + metric_value = rlp_monitor.compute() + + log.info(f"LR scheduler is in RLP mode. RLP metric: {metric_value}") + scheduler.rlp_step(metric_value) + + def _on_train_batch_start_cos_rlp(self): + for scheduler in self._cos_rlp_schedulers(): + scheduler.on_new_step(self.global_step) + + @override + def on_train_batch_start(self, batch: BaseData, batch_idx: int): + match self.config.lr_scheduler: + case WarmupCosRLPConfig(): + self._on_train_batch_start_cos_rlp() + case _: + pass + + @override + def on_validation_epoch_end(self): + match self.config.lr_scheduler: + case WarmupCosRLPConfig() as config: + self._on_validation_epoch_end_cos_rlp(config) + case _: + pass + + @override + def training_step(self, batch: BaseData, batch_idx: int): + with self.log_context(prefix=f"train/{self.metric_prefix()}/"): + preds = self(batch) + + loss = self.compute_losses(batch, preds) + self.log_dict(self.train_metrics(batch, preds)) + + return loss + + @override + def validation_step(self, batch: BaseData, batch_idx: int): + with self.log_context(prefix=f"val/{self.metric_prefix()}/"): + preds = self(batch) + + self.log_dict(self.val_metrics(batch, preds)) + + @override + def test_step(self, batch: BaseData, batch_idx: int): + with self.log_context(prefix=f"test/{self.metric_prefix()}/"): + preds = self(batch) + + self.log_dict(self.test_metrics(batch, preds)) + + def outhead_parameters(self): + head_params = ( + list(self.graph_outputs.parameters()) + + list(self.node_outputs.parameters()) + + list(self.graph_classification_outputs.parameters()) + ) + return head_params + + def backbone_outhead_parameters( + self, + ): + main_params = list(self.parameters()) + head_params = self.outhead_parameters() + head_params_set = set(head_params) + main_params = [p for p in main_params if p not in head_params_set] + return main_params, head_params + + @override + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None): + match self.config.lr_scheduler: + case RLPConfig(warmup=RLPWarmupConfig()): + lr_scheduler = self.lr_schedulers() + assert isinstance(lr_scheduler, GradualWarmupScheduler) + if not lr_scheduler.finished: + lr_scheduler.step() + case _: + pass + + super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure) + + def split_parameters(self, pattern_lists: list[list[str]]): + all_parameters = list(self.parameters()) + + parameters: list[list[torch.nn.Parameter]] = [] + for patterns in pattern_lists: + matching = [ + p for _, p, _ in self.named_parameters_matching_patterns(patterns) + ] + parameters.append(matching) + # remove matching parameters from all_parameters + all_parameters = [ + p for p in all_parameters if all(p is not m for m in matching) + ] + + return parameters, all_parameters + + def _cos_annealing_hparams( + self, lr_config: WarmupCosRLPConfig, *, lr_initial: float + ): + if (warmup_steps := lr_config.warmup_steps) is None: + if warmup_epochs := lr_config.warmup_epochs: + assert warmup_epochs >= 0, f"Invalid warmup_epochs: {warmup_epochs}" + _ = self.trainer.estimated_stepping_batches # make sure dataloaders are loaded for self.trainer.num_training_batches + num_steps_per_epoch = math.ceil( + self.trainer.num_training_batches + / self.trainer.accumulate_grad_batches + ) + warmup_steps = warmup_epochs * num_steps_per_epoch + else: + warmup_steps = 0 + log.critical(f"Computed warmup_steps: {warmup_steps}") + + if not (max_steps := lr_config.max_steps): + if max_epochs := lr_config.max_epochs: + _ = self.trainer.estimated_stepping_batches # make sure dataloaders are loaded for self.trainer.num_training_batches + num_steps_per_epoch = math.ceil( + self.trainer.num_training_batches + / self.trainer.accumulate_grad_batches + ) + max_steps = max_epochs * num_steps_per_epoch + else: + max_steps = self.trainer.estimated_stepping_batches + assert math.isfinite(max_steps), f"{max_steps=} is not finite" + max_steps = int(max_steps) + + log.critical(f"Computed max_steps: {max_steps}") + + assert ( + lr_config.min_lr_factor > 0 and lr_config.min_lr_factor <= 1 + ), f"Invalid {lr_config.min_lr_factor=}" + min_lr = lr_initial * lr_config.min_lr_factor + + assert ( + lr_config.warmup_start_lr_factor > 0 + and lr_config.warmup_start_lr_factor <= 1 + ), f"Invalid {lr_config.warmup_start_lr_factor=}" + warmup_start_lr = lr_initial * lr_config.warmup_start_lr_factor + + lr_scheduler_hparams = dict( + warmup_epochs=warmup_steps, + max_epochs=max_steps, + warmup_start_lr=warmup_start_lr, + eta_min=min_lr, + should_restart=lr_config.should_restart, + ) + + return lr_scheduler_hparams + + def _construct_lr_scheduler( + self, optimizer: torch.optim.Optimizer, config: RLPConfig + ): + assert config.monitor is not None, f"{config=}" + assert config.mode is not None, f"{config=}" + + lr_scheduler = ReduceLROnPlateau( + optimizer, + mode=config.mode, + factor=config.factor, + threshold=config.threshold, + threshold_mode=config.threshold_mode, + patience=config.patience, + cooldown=config.cooldown, + min_lr=config.min_lr, + eps=config.eps, + verbose=True, + ) + if config.warmup is not None: + optim_lr = float(optimizer.param_groups[0]["lr"]) + warmup_start_lr = optim_lr * config.warmup.start_lr_factor + + lr_scheduler = GradualWarmupScheduler( + optimizer, + warmup_start_lr=warmup_start_lr, + warmup_steps=config.warmup.steps, + after_scheduler=lr_scheduler, + ) + return { + "scheduler": lr_scheduler, + "monitor": config.monitor, + "interval": config.interval, + "frequency": config.frequency, + "strict": False, + "reduce_on_plateau": True, + } + else: + return { + "scheduler": lr_scheduler, + "monitor": config.monitor, + "interval": config.interval, + "frequency": config.frequency, + "strict": True, + } + + def configure_optimizers_param_specific_optimizers( + self, configs: list[ParamSpecificOptimizerConfig] + ): + params_list, rest_params = self.split_parameters( + [c.paremeter_patterns for c in configs] + ) + optimizer = optimizer_from_config( + [ + *( + ( + self.config.optimizer if c.optimizer is None else c.optimizer, + params, + c.name or ",".join(c.paremeter_patterns), + ) + for c, params in zip(configs, params_list) + ), + (self.config.optimizer, rest_params, "rest"), + ], + base=self.config.optimizer, + ) + + out: dict[str, Any] = { + "optimizer": optimizer, + } + if (lr_config := self.config.lr_scheduler) is None: + return out + + match lr_config: + case RLPConfig(): + assert all( + c.lr_scheduler is None for c in configs + ), f"lr_scheduler is not None for some configs: {configs=}" + + if ( + lr_scheduler := self._construct_lr_scheduler(optimizer, lr_config) + ) is not None: + out["lr_scheduler"] = lr_scheduler + case WarmupCosRLPConfig(): + param_group_lr_scheduler_settings = [ + *( + self._cos_annealing_hparams( + ( + lr_config + if c.lr_scheduler is None + or not isinstance(c.lr_scheduler, WarmupCosRLPConfig) + else c.lr_scheduler + ), + lr_initial=param_group["lr"], + ) + for c, param_group in zip(configs, optimizer.param_groups[:-1]) + ), + self._cos_annealing_hparams( + lr_config, lr_initial=optimizer.param_groups[-1]["lr"] + ), + ] + + log.critical(f"{param_group_lr_scheduler_settings=}") + lr_scheduler = PerParamGroupLinearWarmupCosineAnnealingRLPLR( + optimizer, + param_group_lr_scheduler_settings, + lr_config.rlp._to_linear_warmup_cos_rlp_dict(), + max_epochs=next( + (s["max_epochs"] for s in param_group_lr_scheduler_settings) + ), + ) + out["lr_scheduler"] = { + "scheduler": lr_scheduler, + "interval": "step", + "frequency": 1, + } + case _: + assert_never(lr_config) + + return out + + @override + def configure_optimizers(self): + if self.config.parameter_specific_optimizers is not None: + return self.configure_optimizers_param_specific_optimizers( + self.config.parameter_specific_optimizers + ) + + optimizer = optimizer_from_config( + [(self.config.optimizer, self.parameters())], + ) + + out: dict[str, Any] = { + "optimizer": optimizer, + } + if (lr_config := self.config.lr_scheduler) is None: + return out + + assert isinstance( + lr_config, RLPConfig + ), "Only RLPConfig is supported if `parameter_specific_optimizers` is None" + if ( + lr_scheduler := self._construct_lr_scheduler(optimizer, lr_config) + ) is not None: + out["lr_scheduler"] = lr_scheduler + + return out + + def process_aint_graph(self, aint_graph: Graph): + return aint_graph + + def generate_graphs( + self, + data: BaseData, + cutoffs: Cutoffs, + max_neighbors: MaxNeighbors, + pbc: bool, + ): + aint_graph = generate_graph( + data, cutoff=cutoffs.aint, max_neighbors=max_neighbors.aint, pbc=pbc + ) + aint_graph = self.process_aint_graph(aint_graph) + subselect = partial( + subselect_graph, + data, + aint_graph, + cutoff_orig=cutoffs.aint, + max_neighbors_orig=max_neighbors.aint, + ) + main_graph = subselect(cutoffs.main, max_neighbors.main) + aeaint_graph = subselect(cutoffs.aeaint, max_neighbors.aeaint) + qint_graph = subselect(cutoffs.qint, max_neighbors.qint) + + # We can't do this at the data level: This is because the batch collate_fn doesn't know + # that it needs to increment the "id_swap" indices as it collates the data. + # So we do this at the graph level (which is done in the GemNetOC `get_graphs_and_indices` method). + # main_graph = symmetrize_edges(main_graph, num_atoms=data.pos.shape[0]) + qint_graph = tag_mask(data, qint_graph, tags=self.config.backbone.qint_tags) + + graphs = { + "main": main_graph, + "a2a": aint_graph, + "a2ee2a": aeaint_graph, + "qint": qint_graph, + } + + for graph_type, graph in graphs.items(): + for key, value in graph.items(): + setattr(data, f"{graph_type}_{key}", value) + + return data + + def create_dataset( + self, split: Literal["train", "val", "test"] + ) -> DatasetType | None: + match split: + case "train": + if (config := self.config.train_dataset) is None: + return None + case "val": + if (config := self.config.val_dataset) is None: + return None + case "test": + if (config := self.config.test_dataset) is None: + return None + case _: + assert_never(split) + + dataset = config.create_dataset() + dataset = wrap_common_dataset(dataset, config) + return dataset + + def validate_dataset(self, dataset: DatasetType): + if self.config.use_balanced_batch_sampler: + assert isinstance( + dataset, DatasetWithSizes + ), f"BalancedBatchSampler requires a DatasetWithSizes, but got {type(dataset)}" + + def _transform_cls_data(self, data: BaseData): + """ + Transforms the classification targets in the given data object based on the configuration. + + For binary classification targets, the target is converted to a float tensor (i.e., 0.0 or 1.0). + For multiclass classification targets, the target is converted to a long tensor (which is used as + the class index by `F.cross_entropy`). + + Args: + data (BaseData): The data object containing the classification targets. + + Returns: + BaseData: The transformed data object. + """ + for target_config in self.config.graph_classification_targets: + match target_config: + case BinaryClassificationTargetConfig(): + if (value := getattr(data, target_config.name, None)) is None: + log.warning(f"target {target_config.name} not found in data") + continue + + setattr(data, target_config.name, value.float()) + case MulticlassClassificationTargetConfig(): + if (value := getattr(data, target_config.name, None)) is None: + log.warning(f"target {target_config.name} not found in data") + continue + + setattr(data, target_config.name, value.long()) + case _: + pass + + return data + + def _apply_dataset_transforms(self, dataset: DatasetType): + dataset = DT.transform(dataset, self.data_transform) + if self.config.normalization: + dataset = DT.transform(dataset, T.normalize(self.config.normalization)) + dataset = DT.transform(dataset, self._transform_cls_data) + return dataset + + def train_dataset(self): + if (dataset := self.create_dataset("train")) is None: + return None + self.validate_dataset(dataset) + dataset = self._apply_dataset_transforms(dataset) + return dataset + + def val_dataset(self): + if (dataset := self.create_dataset("val")) is None: + return None + self.validate_dataset(dataset) + dataset = self._apply_dataset_transforms(dataset) + return dataset + + def test_dataset(self): + if (dataset := self.create_dataset("test")) is None: + return None + self.validate_dataset(dataset) + dataset = self._apply_dataset_transforms(dataset) + return dataset + + def distributed_sampler(self, dataset: Dataset, shuffle: bool): + return DistributedSampler( + dataset, + num_replicas=self.trainer.world_size, + rank=self.trainer.global_rank, + shuffle=shuffle, + ) + + @override + def train_dataloader(self): + if (dataset := self.train_dataset()) is None: + raise ValueError("No train dataset") + + sampler = self.distributed_sampler(dataset, shuffle=True) + if not self.config.use_balanced_batch_sampler: + data_loader = DataLoader( + dataset, + sampler=sampler, + batch_size=self.config.batch_size, + collate_fn=self.collate_fn, + num_workers=self.config.num_workers, + ) + else: + batch_sampler = BalancedBatchSampler( + sampler, + batch_size=self.config.batch_size, + device=self.device, + ) + data_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn, + num_workers=self.config.num_workers, + ) + + return data_loader + + @override + def val_dataloader(self): + if (dataset := self.val_dataset()) is None: + raise ValueError("No val dataset") + + sampler = self.distributed_sampler(dataset, shuffle=self.config.shuffle_val) + batch_size = self.config.eval_batch_size or self.config.batch_size + if not self.config.use_balanced_batch_sampler: + data_loader = DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + collate_fn=self.collate_fn, + num_workers=self.config.num_workers, + ) + else: + batch_sampler = BalancedBatchSampler( + sampler, + batch_size=batch_size, + device=self.device, + ) + data_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn, + num_workers=self.config.num_workers, + ) + return data_loader + + @override + def test_dataloader(self): + if (dataset := self.test_dataset()) is None: + raise ValueError("No test dataset") + + sampler = self.distributed_sampler(dataset, shuffle=self.config.shuffle_test) + batch_size = self.config.eval_batch_size or self.config.batch_size + if not self.config.use_balanced_batch_sampler: + data_loader = DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + collate_fn=self.collate_fn, + num_workers=self.config.num_workers, + ) + else: + batch_sampler = BalancedBatchSampler( + sampler, + batch_size=batch_size, + device=self.device, + ) + data_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn, + num_workers=self.config.num_workers, + ) + return data_loader + + def data_transform(self, data: BaseData): + return data + + def collate_fn(self, data_list: list[BaseData]): + return Batch.from_data_list(data_list) diff --git a/src/jmp/tasks/finetune/dataset_config.py b/src/jmp/tasks/finetune/dataset_config.py new file mode 100644 index 0000000..b5778fd --- /dev/null +++ b/src/jmp/tasks/finetune/dataset_config.py @@ -0,0 +1,173 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from pathlib import Path +from typing import Literal, TypeAlias + +from ...modules.dataset.common import DatasetAtomRefConfig +from .base import FinetuneLmdbDatasetConfig, FinetunePDBBindDatasetConfig +from .matbench import MatbenchDataset, MatbenchFold +from .md22 import MD22Molecule +from .rmd17 import RMD17Molecule +from .spice import SPICEDataset + +Split: TypeAlias = Literal["train", "val", "test"] + + +def matbench_config( + dataset: MatbenchDataset, + base_path: Path, + split: Split, + fold: MatbenchFold, +): + lmdb_path = base_path / "lmdb" / f"matbench_{dataset}" / f"{fold}" / f"{split}" + assert lmdb_path.exists(), f"{lmdb_path} does not exist" + + config = FinetuneLmdbDatasetConfig(src=lmdb_path) + return config + + +def rmd17_config( + molecule: RMD17Molecule, + base_path: Path, + split: Split, +): + lmdb_path = base_path / "lmdb" / f"{molecule}" / f"{split}" + assert lmdb_path.exists(), f"{lmdb_path} does not exist" + + config = FinetuneLmdbDatasetConfig(src=lmdb_path) + return config + + +def md22_config( + molecule: MD22Molecule, + base_path: Path, + split: Split, +): + lmdb_path = base_path / "lmdb" / f"{molecule}" / f"{split}" + assert lmdb_path.exists(), f"{lmdb_path} does not exist" + + config = FinetuneLmdbDatasetConfig(src=lmdb_path) + return config + + +def qm9_config( + base_path: Path, + split: Split, +): + lmdb_path = base_path / "lmdb" / f"{split}" + assert lmdb_path.exists(), f"{lmdb_path} does not exist" + + config = FinetuneLmdbDatasetConfig( + src=lmdb_path, + atom_ref=DatasetAtomRefConfig( + refs={ + "ZPVE": [ + 0.0000e00, + 3.1213e-01, + -1.0408e-17, + 2.7756e-17, + -1.3878e-17, + 0.0000e00, + 1.3694e-01, + 1.3955e-01, + 1.1424e-01, + 9.3038e-02, + ], + "U_0": [ + 0.0000e00, + -1.6430e01, + 1.1369e-13, + -4.5475e-13, + 0.0000e00, + 0.0000e00, + -1.0360e03, + -1.4898e03, + -2.0470e03, + -2.7175e03, + ], + "U": [ + 0.0000e00, + -1.6419e01, + 2.2737e-13, + -4.5475e-13, + 0.0000e00, + 0.0000e00, + -1.0360e03, + -1.4898e03, + -2.0470e03, + -2.7175e03, + ], + "H": [ + 0.0000e00, + -1.6419e01, + 3.4106e-13, + -4.5475e-13, + 0.0000e00, + 0.0000e00, + -1.0360e03, + -1.4898e03, + -2.0470e03, + -2.7175e03, + ], + "G": [ + 0.0000e00, + -1.6443e01, + 3.4106e-13, + -4.5475e-13, + 0.0000e00, + 0.0000e00, + -1.0361e03, + -1.4899e03, + -2.0471e03, + -2.7176e03, + ], + "c_v": [ + 0.0000e00, + 1.2409e00, + -3.3307e-16, + 4.4409e-16, + -1.1102e-16, + 0.0000e00, + 2.0350e00, + 2.7877e00, + 3.0860e00, + 3.3401e00, + ], + } + ), + ) + return config + + +def qmof_config( + base_path: Path, + split: Split, +): + lmdb_path = base_path / "lmdb" / f"{split}" + assert lmdb_path.exists(), f"{lmdb_path} does not exist" + + config = FinetuneLmdbDatasetConfig(src=lmdb_path) + return config + + +def spice_config( + dataset: SPICEDataset, + base_path: Path, + split: Split, +): + lmdb_path = base_path / "lmdb" / f"{dataset}" / f"{split}" + assert lmdb_path.exists(), f"{lmdb_path} does not exist" + + config = FinetuneLmdbDatasetConfig(src=lmdb_path) + return config + + +def pdbbind_config(split: Split): + config = FinetunePDBBindDatasetConfig(split=split) + return config diff --git a/src/jmp/tasks/finetune/energy_forces_base.py b/src/jmp/tasks/finetune/energy_forces_base.py new file mode 100644 index 0000000..afa3a10 --- /dev/null +++ b/src/jmp/tasks/finetune/energy_forces_base.py @@ -0,0 +1,326 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from abc import ABC, abstractmethod +from collections.abc import Mapping +from contextlib import ExitStack +from logging import getLogger +from typing import Any, Generic, Literal, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import pack, rearrange, reduce +from jmp.lightning import TypedConfig +from torch_geometric.data.data import BaseData +from torch_scatter import scatter +from typing_extensions import TypeVar, override + +from ...models.gemnet.backbone import GOCBackboneOutput +from ...models.gemnet.layers.force_scaler import ForceScaler +from ...modules.dataset import dataset_transform as DT +from .base import FinetuneConfigBase, FinetuneModelBase + +log = getLogger(__name__) + + +class PretrainOutputHeadConfig(TypedConfig): + enabled: bool = False + + num_pretrain_heads: int = 4 + energy_reduction: str = "sum" + + direct_forces: bool = False + gradient_forces: bool = False + + combine_strategy: str = "mean" + + +class EnergyForcesConfigBase(FinetuneConfigBase): + graph_scalar_targets: list[str] = ["y"] + node_vector_targets: list[str] = ["force"] + + graph_scalar_loss_coefficients: dict[str, float] = {"y": 1.0} + node_vector_loss_coefficients: dict[str, float] = {"force": 100.0} + + gradient_forces: bool = False + model_type: Literal["energy", "forces", "energy_forces"] = "energy_forces" + + pretrain_output_head: PretrainOutputHeadConfig = PretrainOutputHeadConfig() + + @override + def __post_init__(self): + super().__post_init__() + + if self.gradient_forces: + assert ( + not self.trainer.inference_mode + ), "Gradient forces requires trainer.inference_mode = False" + + +TConfig = TypeVar("TConfig", bound=EnergyForcesConfigBase, infer_variance=True) + + +class EnergyForcesModelBase( + FinetuneModelBase[TConfig], nn.Module, ABC, Generic[TConfig] +): + @override + def validate_config(self, config: TConfig): + super().validate_config(config) + + assert config.model_type in ("energy", "forces", "energy_forces"), ( + f"{config.model_type=} must be one of these values: " + "energy, forces, energy_forces" + ) + + # if config.gradient_forces: + # assert ( + # not config.trainer.inference_mode + # ), f"config.trainer.inference_mode must be False when {config.gradient_forces=}" + + @override + def __init__(self, hparams: TConfig): + super().__init__(hparams) + + if self.config.gradient_forces: + self.force_scaler = ForceScaler() + + @override + def load_backbone_state_dict( + self, + *, + backbone: Mapping[str, Any], + embedding: Mapping[str, Any], + output: Mapping[str, Any] | None = None, + strict: bool = True, + ): + super().load_backbone_state_dict( + backbone=backbone, embedding=embedding, strict=strict + ) + + if self.config.pretrain_output_head.enabled: + assert ( + output is not None + ), "output must be provided when pretrain_output_head is enabled" + self._load_pretrain_output_state_dict(output) + + def construct_energy_head(self): + def dims( + emb_size: int, + *, + num_targets: int = self.config.backbone.num_targets, + num_mlps: int = self.config.output.num_mlps, + ): + return ([emb_size] * num_mlps) + [num_targets] + + self.out_energy = self.mlp( + dims(self.config.backbone.emb_size_atom), + activation=self.config.activation_cls, + bias=False, + ) + + @override + def construct_output_heads(self): + def dims( + emb_size: int, + *, + num_targets: int = self.config.backbone.num_targets, + num_mlps: int = self.config.output.num_mlps, + ): + return ([emb_size] * num_mlps) + [num_targets] + + self.out_energy = None + if ( + self.config.model_type in ("energy", "energy_forces") + or self.config.gradient_forces + ): + self.construct_energy_head() + + self.out_forces = None + if ( + self.config.model_type in ("forces", "energy_forces") + and not self.config.gradient_forces + ): + self.out_forces = self.mlp( + dims(self.config.backbone.emb_size_edge), + activation=self.config.activation_cls, + ) + + @override + def outhead_parameters(self): + head_params: list[nn.Parameter] = [] + if self.out_energy is not None: + head_params.extend(self.out_energy.parameters()) + if self.out_forces is not None: + head_params.extend(self.out_forces.parameters()) + return head_params + + def combine_outputs( + self, + energy_list: list[torch.Tensor], + forces_list: list[torch.Tensor], + ): + energy, _ = pack(energy_list, "b *") # (bsz, T) + forces, _ = pack(forces_list, "n p *") # (N, 3, T) + + match self.config.pretrain_output_head.combine_strategy: + case "mean": + energy = reduce(energy, "b T -> b", "mean") + forces = reduce(forces, "n p T -> n p", "mean") + case _: + raise ValueError( + f"Unknown combine strategy: {self.config.pretrain_output_head.combine_strategy}" + ) + + return energy, forces + + @override + def forward(self, data: BaseData): + preds: dict[str, torch.Tensor] = {} + with ExitStack() as stack: + if self.config.gradient_forces or ( + self.config.pretrain_output_head.enabled + and self.config.pretrain_output_head.gradient_forces + ): + stack.enter_context(torch.inference_mode(mode=False)) + stack.enter_context(torch.enable_grad()) + + data.pos.requires_grad_(True) + data = self.generate_graphs_transform(data) + + atomic_numbers = data.atomic_numbers - 1 + h = self.embedding(atomic_numbers) + out: GOCBackboneOutput = self.backbone(data, h=h) + + n_molecules = int(torch.max(data.batch).item() + 1) + n_atoms = data.atomic_numbers.shape[0] + + if self.out_energy is not None: + output = self.out_energy(out["energy"]) # (n_atoms, 1) + + # TODO: set reduce to config + output = scatter( + output, + data.batch, + dim=0, + dim_size=n_molecules, + reduce="sum", + ) + preds["y"] = rearrange(output, "b 1 -> b") + + if self.out_forces is not None: + output = self.out_forces(out["forces"]) + output = output * out["V_st"] + output = scatter( + output, out["idx_t"], dim=0, dim_size=n_atoms, reduce="sum" + ) + preds["force"] = output + + if self.config.gradient_forces: + assert "force" not in preds, f"force already in preds: {preds.keys()}" + assert ( + energy := preds.get("y") + ) is not None, f"energy not in preds: {preds.keys()}" + preds["force"] = self.force_scaler.calc_forces_and_update( + energy, data.pos + ) + + if self.config.pretrain_output_head.enabled: + pretrain_energies, pretrain_forces = self.pretrain_output( + data, out + ) # (bsz, T), (N, 3, T) + + pretrain_energies = cast(list[torch.Tensor], pretrain_energies) + pretrain_forces = cast(list[torch.Tensor], pretrain_forces) + + gradient_forces: list[torch.Tensor] = [] + if self.config.pretrain_output_head.gradient_forces: + for energy in pretrain_energies: + # energy: (bsz) + forces = self.force_scaler.calc_forces_and_update( + energy, data.pos + ) # (N, 3) + gradient_forces.append(forces) + + all_energies = [preds["y"]] + pretrain_energies + all_forces = [preds["force"]] + pretrain_forces + gradient_forces + + preds["y"], preds["force"] = self.combine_outputs( + all_energies, all_forces + ) + + return preds + + @override + def compute_losses(self, batch: BaseData, preds: dict[str, torch.Tensor]): + losses: list[torch.Tensor] = [] + + if self.config.model_type in ("energy", "energy_forces"): + loss = F.l1_loss(preds["y"], batch["y"]) + self.log("y_loss", loss) + + coef = self.config.graph_scalar_loss_coefficients.get( + "y", self.config.graph_scalar_loss_coefficient_default + ) + loss = coef * loss + self.log("y_loss_scaled", loss) + losses.append(loss) + + if self.config.model_type in ("forces", "energy_forces"): + assert preds["force"].shape[-1] == 3, f"{preds['force'].shape=}" + + loss = F.pairwise_distance(preds["force"], batch["force"], p=2.0).mean() + self.log("force_loss", loss) + + coef = self.config.node_vector_loss_coefficients.get( + "force", self.config.node_vector_loss_coefficient_default + ) + loss = coef * loss + self.log("force_loss_scaled", loss) + + losses.append(loss) + + loss = sum(losses) + self.log("loss", loss) + + return loss + + @abstractmethod + def generate_graphs_transform(self, data: BaseData) -> BaseData: ... + + def _generate_graphs_transform(self, data: BaseData): + if self.config.gradient_forces: + # We need to compute the graphs in the forward method + # so that we can compute the forces using the energy + # and the positions. + return data + return self.generate_graphs_transform(data) + + @override + def train_dataset(self): + if (dataset := super().train_dataset()) is None: + return None + + dataset = DT.transform(dataset, transform=self._generate_graphs_transform) + return dataset + + @override + def val_dataset(self): + if (dataset := super().val_dataset()) is None: + return None + + dataset = DT.transform(dataset, transform=self._generate_graphs_transform) + return dataset + + @override + def test_dataset(self): + if (dataset := super().test_dataset()) is None: + return None + + dataset = DT.transform(dataset, transform=self._generate_graphs_transform) + return dataset diff --git a/src/jmp/tasks/finetune/matbench.py b/src/jmp/tasks/finetune/matbench.py new file mode 100644 index 0000000..ddd8e4c --- /dev/null +++ b/src/jmp/tasks/finetune/matbench.py @@ -0,0 +1,97 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Literal, TypeAlias, final + +import torch +from torch_geometric.data.data import BaseData +from typing_extensions import override + +from ...utils.goc_graph import Cutoffs, Graph, MaxNeighbors +from . import base + +MatbenchDataset: TypeAlias = Literal[ + "jdft2d", + "phonons", + "dielectric", + "log_gvrh", + "log_kvrh", + "perovskites", + "mp_gap", + "mp_e_form", + "mp_is_metal", +] +MatbenchFold: TypeAlias = Literal[0, 1, 2, 3, 4] + + +class MatbenchConfig(base.FinetuneConfigBase): + dataset: MatbenchDataset + graph_scalar_targets: list[str] = ["y"] + node_vector_targets: list[str] = [] + + graph_scalar_reduction_default: Literal["sum", "mean", "max"] = "mean" + + fold: MatbenchFold = 0 + mp_e_form_dev: bool = True + + conditional_max_neighbors: bool = False + + +@final +class MatbenchModel(base.FinetuneModelBase[MatbenchConfig]): + @classmethod + @override + def config_cls(cls): + return MatbenchConfig + + @override + def metric_prefix(self) -> str: + return f"matbench/{self.config.dataset}" + + @override + def process_aint_graph(self, aint_graph: Graph): + return aint_graph + + @override + def data_transform(self, data: BaseData): + data = super().data_transform(data) + + if not torch.is_tensor(data.y): + data.y = torch.tensor(data.y) + data.y = data.y.view(-1) + + if self.config.dataset == "mp_is_metal": + data.y = data.y.bool() + + data.atomic_numbers = data.atomic_numbers.long() + assert data.num_nodes is not None + data.natoms = data.num_nodes + + data.tags = 2 * torch.ones(data.natoms) + data.tags = data.tags.long() + + data.fixed = torch.zeros(data.natoms, dtype=torch.bool) + + data.pos = data.pos.float() + + max_neighbors = 30 + if self.config.conditional_max_neighbors: + if data.natoms > 300: + max_neighbors = 5 + elif data.natoms > 200: + max_neighbors = 10 + else: + max_neighbors = 30 + + data = self.generate_graphs( + data, + cutoffs=Cutoffs.from_constant(12.0), + max_neighbors=MaxNeighbors.from_goc_base_proportions(max_neighbors), + pbc=True, + ) + return data diff --git a/src/jmp/tasks/finetune/md22.py b/src/jmp/tasks/finetune/md22.py new file mode 100644 index 0000000..87061c4 --- /dev/null +++ b/src/jmp/tasks/finetune/md22.py @@ -0,0 +1,77 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Literal, TypeAlias, final + +import torch +from torch_geometric.data.data import BaseData +from typing_extensions import override + +from ...utils.goc_graph import Cutoffs, Graph, MaxNeighbors +from .energy_forces_base import EnergyForcesConfigBase, EnergyForcesModelBase + +MD22Molecule: TypeAlias = Literal[ + "Ac-Ala3-NHMe", + "DHA", + "stachyose", + "AT-AT", + "AT-AT-CG-CG", + "buckyball-catcher", + "double-walled_nanotube", +] + + +class MD22Config(EnergyForcesConfigBase): + molecule: MD22Molecule + + graph_scalar_targets: list[str] = ["y"] + node_vector_targets: list[str] = ["force"] + + graph_scalar_loss_coefficients: dict[str, float] = {"y": 1.0} + node_vector_loss_coefficients: dict[str, float] = {"force": 100.0} + + +@final +class MD22Model(EnergyForcesModelBase[MD22Config]): + @classmethod + @override + def config_cls(cls): + return MD22Config + + @override + def metric_prefix(self) -> str: + return f"md22/{self.config.molecule}" + + @override + def generate_graphs_transform(self, data: BaseData): + return self.generate_graphs( + data, + cutoffs=Cutoffs.from_constant(12.0), + max_neighbors=MaxNeighbors.from_goc_base_proportions(30), + pbc=False, + ) + + @override + def process_aint_graph(self, aint_graph: Graph): + return aint_graph + + @override + def data_transform(self, data: BaseData): + data = super().data_transform(data) + + data.y = data.pop("y").view(-1).float() + data.atomic_numbers = data.pop("atomic_numbers").long() + data.natoms = data.num_nodes + + data.tags = 2 * torch.ones(data.natoms) + data.tags = data.tags.long() + + data.fixed = torch.zeros(data.natoms, dtype=torch.bool) + + data.cell = (torch.eye(3) * 1000.0).unsqueeze(dim=0) + return data diff --git a/src/jmp/tasks/finetune/metrics.py b/src/jmp/tasks/finetune/metrics.py new file mode 100644 index 0000000..038163e --- /dev/null +++ b/src/jmp/tasks/finetune/metrics.py @@ -0,0 +1,347 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections import Counter +from dataclasses import dataclass +from logging import getLogger +from typing import TYPE_CHECKING, Protocol, cast, runtime_checkable + +import torch +import torch.nn as nn +import torchmetrics +import torchmetrics.classification +from frozendict import frozendict +from jmp.lightning import TypedConfig +from jmp.lightning.util.typed import TypedModuleDict +from torch_geometric.data.data import BaseData +from typing_extensions import override + +if TYPE_CHECKING: + from .base import ( + BinaryClassificationTargetConfig, + MulticlassClassificationTargetConfig, + ) + +log = getLogger(__name__) + + +class CheckConflictingStructuresConfig(TypedConfig): + structures: dict[str, str] = {} + """ + A dictionary which maps from pre-training dataset names to the path of the + pickle files (saved using `torch.save` as sets of frozendict[int, int] objects) + containing the unique atomic numbers in the dataset. + + The frozendict[int, int] objects are mappings from atomic numbers to the number + of atoms with that atomic number in the structure. + """ + + all: bool = True + """Also check for conflicting structures across all datasets (i.e., the union of all structures)""" + + +class MetricsConfig(TypedConfig): + check_conflicting_structures: CheckConflictingStructuresConfig | None = None + """ + Configuration for checking conflicting structures. + + This is used to, for example, check to see what percentage + of structures in the fine-tuning dataset (e.g., in QM9) exist + in the pre-training dataset (e.g., ANI-1x). + """ + + report_rmse: bool = False + """Whether to report RMSE in addition to MAE""" + + +@dataclass(frozen=True) +class MetricPair: + predicted: torch.Tensor + ground_truth: torch.Tensor + + +@runtime_checkable +class MetricPairProvider(Protocol): + def __call__( + self, prop: str, batch: BaseData, preds: dict[str, torch.Tensor] + ) -> MetricPair | None: ... + + +class ConflictingMetrics(nn.Module): + @override + def __init__( + self, + graph_targets: list[str], + node_targets: list[str], + structures: set[frozendict[int, int]], + provider: MetricPairProvider, + ): + super().__init__() + + self.graph_targets = graph_targets + self.node_targets = node_targets + self.targets = graph_targets + node_targets + + self.structures = structures + self.conflicting_maes = TypedModuleDict( + {target: torchmetrics.MeanAbsoluteError() for target in self.targets} + ) + self.non_conflicting_maes = TypedModuleDict( + {target: torchmetrics.MeanAbsoluteError() for target in self.targets} + ) + + self.num_conflicting = torchmetrics.SumMetric() + self.num_non_conflicting = torchmetrics.SumMetric() + self.num_total = torchmetrics.SumMetric() + + self.provider = provider + + def _compute_mask(self, data: BaseData): + n_graphs = int(torch.max(data.batch).item() + 1) + mask = torch.zeros(n_graphs, dtype=torch.bool, device=data.batch.device) + for i in range(n_graphs): + # get the atomic numbers for the current molecule + atomic_numbers = data.atomic_numbers[data.batch == i].long() + atomic_numbers_dict = frozendict(Counter(atomic_numbers.tolist())) + mask[i] = atomic_numbers_dict in self.structures + return mask + + def _compute_metrics( + self, + targets: list[str], + batch: BaseData, + preds: dict[str, torch.Tensor], + mask: torch.Tensor, + ): + for key in targets: + if (mp := self.provider(key, batch, preds)) is None: + continue + + conflicting_mae = self.conflicting_maes[key] + non_conflicting_mae = self.non_conflicting_maes[key] + + conflicting_mae(mp.predicted[mask], mp.ground_truth[mask]) + non_conflicting_mae(mp.predicted[~mask], mp.ground_truth[~mask]) + + @override + def forward(self, batch: BaseData, preds: dict[str, torch.Tensor]): + mask = self._compute_mask(batch) + + self.num_conflicting(mask) + self.num_non_conflicting(~mask) + self.num_total(torch.ones_like(mask)) + + self._compute_metrics(self.graph_targets, batch, preds, mask) + self._compute_metrics(self.node_targets, batch, preds, mask[batch.batch]) + + metrics: dict[str, torchmetrics.Metric] = {} + metrics["num_conflicting"] = self.num_conflicting + metrics["num_non_conflicting"] = self.num_non_conflicting + metrics["num_total"] = self.num_total + + for key in self.targets: + metrics[f"{key}_conflicting_mae"] = self.conflicting_maes[key] + metrics[f"{key}_non_conflicting_mae"] = self.non_conflicting_maes[key] + + return metrics + + +class BinaryClassificationMetrics(nn.Module): + def __init__(self, num_classes: int): + super().__init__() + + assert num_classes == 2, "Only binary classification is supported" + + self.roc_auc = torchmetrics.classification.BinaryAUROC() + self.f1 = torchmetrics.classification.F1Score(task="binary") + self.balanced_accuracy = torchmetrics.classification.MulticlassAccuracy( + average="macro", num_classes=2 + ) + + def compute(self): + # This method returns a Tensor which contains the metric used for RLP + return self.balanced_accuracy.compute() + + @override + def forward(self, pred: torch.Tensor, target: torch.Tensor): + metrics: dict[str, torchmetrics.Metric] = {} + + self.roc_auc(pred, target) + self.f1(pred, target) + + # For balanced accuracy, we need to convert the binary pred/target to a + # multiclass target with 2 classes. This is because torchmetrics' implementation + # of torchmetrics.classification.BinaryAccuracy does not support the "macro" + # average, which is what we want. + cls_pred = pred.new_zeros((*pred.shape, 2)) + cls_pred[..., 1] = pred + cls_pred[..., 0] = 1 - pred + + cls_target = target.long() + self.balanced_accuracy(cls_pred, cls_target) + + metrics["roc_auc"] = self.roc_auc + metrics["f1"] = self.f1 + metrics["balanced_accuracy"] = self.balanced_accuracy + + return metrics + + +class MulticlassClassificationMetrics(nn.Module): + def __init__(self, num_classes: int): + super().__init__() + + self.roc_auc = torchmetrics.classification.AUROC( + task="multiclass", num_classes=num_classes + ) + self.f1 = torchmetrics.classification.F1Score( + task="multiclass", num_classes=num_classes + ) + self.balanced_accuracy = torchmetrics.classification.MulticlassAccuracy( + average="macro", num_classes=num_classes + ) + self.num_classes = num_classes + + def compute(self): + # This method returns a Tensor which contains the metric used for RLP + return self.balanced_accuracy.compute() + + @override + def forward(self, pred: torch.Tensor, target: torch.Tensor): + metrics: dict[str, torchmetrics.Metric] = {} + + self.roc_auc(pred, target) + self.f1(pred, target) + self.balanced_accuracy(pred, target) + + metrics["roc_auc"] = self.roc_auc + metrics["f1"] = self.f1 + metrics["balanced_accuracy"] = self.balanced_accuracy + + return metrics + + +class FinetuneMetrics(nn.Module): + @property + def regression_targets(self): + return self.graph_scalar_targets + self.node_vector_targets + + @override + def __init__( + self, + config: MetricsConfig, + provider: MetricPairProvider, + graph_scalar_targets: list[str], + graph_classification_targets: "list[BinaryClassificationTargetConfig | MulticlassClassificationTargetConfig]", + node_vector_targets: list[str], + ): + super().__init__() + + if not isinstance(provider, MetricPairProvider): + raise TypeError( + f"Expected {provider=} to be an instance of {MetricPairProvider=}" + ) + self.provider = provider + + self.config = config + self.graph_scalar_targets = graph_scalar_targets + self.graph_classification_targets = graph_classification_targets + self.node_vector_targets = node_vector_targets + + self.maes = TypedModuleDict( + { + target: torchmetrics.MeanAbsoluteError() + for target in self.regression_targets + }, + key_prefix="mae_", + ) + if self.config.report_rmse: + self.rmses = TypedModuleDict( + { + target: torchmetrics.MeanSquaredError(squared=False) + for target in self.regression_targets + }, + key_prefix="rmse_", + ) + self.cls_metrics = TypedModuleDict( + { + target.name: ( + BinaryClassificationMetrics(target.num_classes) + if isinstance(target, BinaryClassificationTargetConfig) + else MulticlassClassificationMetrics(target.num_classes) + ) + for target in self.graph_classification_targets + }, + key_prefix="cls_", + ) + + if (ccs := self.config.check_conflicting_structures) is not None: + metrics_dict: dict[str, ConflictingMetrics] = {} + all_structures = set[frozendict[int, int]]() + for name, structures in ccs.structures.items(): + structures = cast(set[frozendict[int, int]], torch.load(structures)) + metrics_dict[name] = ConflictingMetrics( + graph_targets=self.graph_scalar_targets, + node_targets=self.node_vector_targets, + structures=structures, + provider=self.provider, + ) + if ccs.all: + all_structures.update(structures) + + if ccs.all: + metrics_dict["all"] = ConflictingMetrics( + graph_targets=self.graph_scalar_targets, + node_targets=self.node_vector_targets, + structures=all_structures, + provider=self.provider, + ) + self.conflicting = TypedModuleDict(metrics_dict) + + @override + def forward(self, batch: BaseData, preds: dict[str, torch.Tensor]): + metrics: dict[str, torchmetrics.Metric] = {} + + for key, mae in self.maes.items(): + if (mp := self.provider(key, batch, preds)) is None: + continue + + mae(mp.predicted, mp.ground_truth) + metrics[f"{key}_mae"] = mae + + if self.config.report_rmse: + for key, rmse in self.rmses.items(): + if (mp := self.provider(key, batch, preds)) is None: + continue + + rmse(mp.predicted, mp.ground_truth) + metrics[f"{key}_rmse"] = rmse + + for key, cls_metric in self.cls_metrics.items(): + if (mp := self.provider(key, batch, preds)) is None: + continue + + metric_dict = cls_metric(mp.predicted, mp.ground_truth) + metrics.update( + { + f"{key}_{metric_name}": metric + for metric_name, metric in metric_dict.items() + } + ) + + if self.config.check_conflicting_structures is not None: + for name, conflicting in self.conflicting.items(): + metric_dict = conflicting(batch, preds) + metrics.update( + { + f"conflicting/{name}_{metric_name}": metric + for metric_name, metric in metric_dict.items() + } + ) + + return metrics diff --git a/src/jmp/tasks/finetune/pdbbind.py b/src/jmp/tasks/finetune/pdbbind.py new file mode 100644 index 0000000..d65a3bd --- /dev/null +++ b/src/jmp/tasks/finetune/pdbbind.py @@ -0,0 +1,146 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from dataclasses import replace +from typing import final + +import torch +from torch_geometric.data.batch import Batch +from torch_geometric.data.data import BaseData +from typing_extensions import override + +from ...datasets.finetune_pdbbind import PDBBindTask +from ...utils.goc_graph import Cutoffs, Graph, MaxNeighbors +from .base import FinetuneConfigBase, FinetuneModelBase, FinetunePDBBindDatasetConfig +from .metrics import MetricPair + + +class PDBBindConfig(FinetuneConfigBase): + pbdbind_task: PDBBindTask + + graph_scalar_targets: list[str] = ["y"] + node_vector_targets: list[str] = [] + + cutoff: float = 12.0 + max_neighbors: int = 30 + pbc: bool = False + + @override + def __post_init__(self): + super().__post_init__() + + all_datasets: list[FinetunePDBBindDatasetConfig] = [] + if self.train_dataset is not None: + assert isinstance( + self.train_dataset, FinetunePDBBindDatasetConfig + ), "dataset config must be of type FinetunePDBBindDatasetConfig" + all_datasets.append(self.train_dataset) + + if self.val_dataset is not None: + assert isinstance( + self.val_dataset, FinetunePDBBindDatasetConfig + ), "dataset config must be of type FinetunePDBBindDatasetConfig" + all_datasets.append(self.val_dataset) + + if self.test_dataset is not None: + assert isinstance( + self.test_dataset, FinetunePDBBindDatasetConfig + ), "dataset config must be of type FinetunePDBBindDatasetConfig" + all_datasets.append(self.test_dataset) + + # Make sure all datasets have the same task + for dataset in all_datasets: + assert ( + dataset.task == self.pbdbind_task + ), "All datasets must have the same task" + + +@final +class PDBBindModel(FinetuneModelBase[PDBBindConfig]): + @classmethod + @override + def config_cls(cls): + return PDBBindConfig + + @override + def metric_prefix(self) -> str: + return f"pdbbind/{self.config.pbdbind_task}" + + @override + def metrics_provider( + self, + prop: str, + batch: Batch, + preds: dict[str, torch.Tensor], + ) -> MetricPair | None: + """ + For PDBbind, the moleculenet dataset already normalizes the properties when it gives it to us. + Therefore, we need to change the logic for unnormalizing the properties for metrics. + """ + if (pair := super().metrics_provider(prop, batch, preds)) is None: + return None + + # Get the mean and std of the property from the dataset and unnormalize for metrics + mean = getattr( + batch, + f"{prop}_mean", + torch.tensor( + 0.0, + device=pair.ground_truth.device, + dtype=pair.ground_truth.dtype, + ), + ) + std = getattr( + batch, + f"{prop}_std", + torch.tensor( + 1.0, + device=pair.ground_truth.device, + dtype=pair.ground_truth.dtype, + ), + ) + + pair = replace( + pair, + ground_truth=(pair.ground_truth * std) + mean, + predicted=(pair.predicted * std) + mean, + ) + return pair + + @override + def process_aint_graph(self, aint_graph: Graph): + return aint_graph + + @override + def data_transform(self, data: BaseData): + data = super().data_transform(data) + + data = copy.deepcopy(data) + if not torch.is_tensor(data.y): + data.y = torch.tensor(data.y, dtype=torch.float) + data.y = data.y.float().view(-1) + data.atomic_numbers = data.atomic_numbers.long() + data.natoms = data.num_nodes + + data.tags = 2 * torch.ones(data.natoms) + data.tags = data.tags.long() + + data.fixed = torch.zeros(data.natoms, dtype=torch.bool) + + data.pos = data.pos.float() + data = self.generate_graphs( + data, + cutoffs=Cutoffs.from_constant(self.config.cutoff), + max_neighbors=MaxNeighbors.from_goc_base_proportions( + self.config.max_neighbors + ), + pbc=self.config.pbc, + ) + + return data diff --git a/src/jmp/tasks/finetune/qm9.py b/src/jmp/tasks/finetune/qm9.py new file mode 100644 index 0000000..b528898 --- /dev/null +++ b/src/jmp/tasks/finetune/qm9.py @@ -0,0 +1,239 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections.abc import Callable +from typing import Annotated, Literal, TypeAlias, assert_never, final + +import torch +import torch.nn as nn +from ase.data import atomic_masses +from einops import rearrange +from jmp.lightning import Field, TypedConfig +from torch_geometric.data.data import BaseData +from torch_scatter import scatter +from typing_extensions import override + +from ...utils.goc_graph import Cutoffs, Graph, MaxNeighbors +from .base import FinetuneConfigBase, FinetuneModelBase, OutputHeadInput + +QM9Target: TypeAlias = Literal[ + "mu", # dipole_moment + "alpha", # isotropic_polarizability + "eps_HOMO", # hOMO + "eps_LUMO", # lumo + "delta_eps", # homo_lumo_gap + "R_2_Abs", # electronicspatial_extent + "ZPVE", # zpve + "U_0", # energy_U0 + "U", # energy_U + "H", # enthalpy_H + "G", # free_energy + "c_v", # heat_capacity + "U_0_ATOM", # atomization_energy_U0 + "U_ATOM", # atomization_energy_U + "H_ATOM", # atomization_enthalpy_H + "G_ATOM", # atomization_free_energy + "A", # rotational_constant_A + "B", # rotational_constant_B + "C", # rotational_constant_C +] + + +class DefaultOutputHeadConfig(TypedConfig): + name: Literal["default"] = "default" + + +class SpatialExtentConfig(TypedConfig): + name: Literal["spatial_extent"] = "spatial_extent" + + +OutputHeadConfig: TypeAlias = Annotated[ + DefaultOutputHeadConfig | SpatialExtentConfig, + Field(discriminator="name"), +] + + +class QM9Config(FinetuneConfigBase): + graph_scalar_targets: list[str] = [] + node_vector_targets: list[str] = [] + + graph_scalar_reduction: dict[str, Literal["sum", "mean", "max"]] = { + "mu": "sum", + "alpha": "sum", + "eps_HOMO": "sum", + "eps_LUMO": "sum", + "delta_eps": "sum", + "R_2_Abs": "sum", + "ZPVE": "sum", + "U_0": "sum", + "U": "sum", + "H": "sum", + "G": "sum", + "c_v": "sum", + } + + output_head: OutputHeadConfig = DefaultOutputHeadConfig() + + max_neighbors: int = 30 + + +class SpatialExtentOutputHead(nn.Module): + @override + def __init__(self, atomic_masses: Callable[[], torch.Tensor], reduction: str): + super().__init__() + + if reduction is None: + reduction = self.config.graph_scalar_reduction_default + assert reduction == "sum", f"reduction must be sum, got {self.reduction=}" + + dim = self.config.backbone.emb_size_atom + scalar_mlp_layers: list[nn.Module] = [] + for _ in range(self.config.output.num_mlps): + scalar_mlp_layers.append(nn.Linear(dim, dim, bias=False)) + scalar_mlp_layers.append(self.config.activation_cls()) + scalar_mlp_layers.append(nn.Linear(dim, 1, bias=False)) + self.scalar_mlp = nn.Sequential(*scalar_mlp_layers) + + self.atomic_masses = atomic_masses + self.reduction = reduction + + @override + def forward(self, input: OutputHeadInput): + data = input["data"] + backbone_output = input["backbone_output"] + x = self.scalar_mlp(backbone_output["energy"]) # n 1 + + batch_size = int(torch.max(data.batch).item() + 1) + + # Get the center of mass + masses = self.atomic_masses()[data.atomic_numbers] # n + center_of_mass = scatter( + masses.unsqueeze(-1) * data.pos, # n 3 + data.batch, + dim=0, + dim_size=batch_size, + reduce="sum", + ) / scatter( + masses.unsqueeze(-1), + data.batch, + dim=0, + dim_size=batch_size, + reduce="sum", + ) # b 3 + + # Get the squared norm of each position vector + pos_norm_sq = ( + torch.linalg.vector_norm( + data.pos - center_of_mass[data.batch], + dim=-1, + keepdim=True, + ord=2, + ) + ** 2 + ) # n 1 + x = x * pos_norm_sq # n 1 + + # Apply the reduction + x = scatter( + x, + data.batch, + dim=0, + dim_size=batch_size, + reduce=self.reduction, + ) # (bsz, 1) + + x = rearrange(x, "b 1 -> b") + return x + + +@final +class QM9Model(FinetuneModelBase[QM9Config]): + targets: list[QM9Target] = [ + "mu", + "alpha", + "eps_HOMO", + "eps_LUMO", + "delta_eps", + "R_2_Abs", + "ZPVE", + "U_0", + "U", + "H", + "G", + "c_v", + "U_0_ATOM", + "U_ATOM", + "H_ATOM", + "G_ATOM", + "A", + "B", + "C", + ] + + atomic_masses: torch.Tensor + + @override + def __init__(self, hparams: QM9Config): + super().__init__(hparams) + + self.register_buffer( + "atomic_masses", + torch.from_numpy(atomic_masses).float(), + persistent=False, + ) + + @override + def validate_config(self, config: QM9Config): + super().validate_config(config) + + for key in config.graph_scalar_targets: + assert key in self.targets, f"{key} is not a valid QM9 target" + + @classmethod + @override + def config_cls(cls): + return QM9Config + + @override + def metric_prefix(self) -> str: + return "qm9" + + @override + def construct_graph_scalar_output_head(self, target: str): + reduction = self.config.graph_scalar_reduction.get( + target, self.config.graph_scalar_reduction_default + ) + match self.config.output_head: + case SpatialExtentConfig(): + # This is only supported for R_2_Abs + assert ( + target == "R_2_Abs" + ), f"{target} is not supported for spatial extent" + + return SpatialExtentOutputHead(lambda: self.atomic_masses, reduction) + case DefaultOutputHeadConfig(): + return super().construct_graph_scalar_output_head(target) + case _: + assert_never(self.config.output_head) + + @override + def process_aint_graph(self, aint_graph: Graph): + return aint_graph + + @override + def data_transform(self, data: BaseData): + data = super().data_transform(data) + + data = self.generate_graphs( + data, + cutoffs=Cutoffs.from_constant(8.0), + max_neighbors=MaxNeighbors.from_goc_base_proportions(30), + pbc=False, + ) + + return data diff --git a/src/jmp/tasks/finetune/qmof.py b/src/jmp/tasks/finetune/qmof.py new file mode 100644 index 0000000..01ec9c1 --- /dev/null +++ b/src/jmp/tasks/finetune/qmof.py @@ -0,0 +1,87 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from typing import Literal, final + +import torch +from torch_geometric.data.data import BaseData +from typing_extensions import override + +from ...utils.goc_graph import Cutoffs, Graph, MaxNeighbors +from .base import FinetuneConfigBase, FinetuneModelBase + + +class QMOFConfig(FinetuneConfigBase): + graph_scalar_targets: list[str] = ["y"] + node_vector_targets: list[str] = [] + + graph_scalar_reduction_default: Literal["sum", "mean", "max"] = "mean" + + +@final +class QMOFModel(FinetuneModelBase[QMOFConfig]): + @classmethod + @override + def config_cls(cls): + return QMOFConfig + + @override + def metric_prefix(self) -> str: + return "qmof" + + @override + def training_step(self, batch, batch_idx): + with self.log_context(prefix=f"train/{self.metric_prefix()}/"): + preds = self(batch) + + loss = self.compute_losses(batch, preds) + self.log_dict(self.train_metrics(batch, preds)) + + return loss + + @override + def process_aint_graph(self, aint_graph: Graph): + return aint_graph + + @override + def data_transform(self, data: BaseData): + data = super().data_transform(data) + + data = copy.deepcopy(data) + if not torch.is_tensor(data.y): + data.y = torch.tensor(data.y) + data.y = data.y.view(-1) + data.atomic_numbers = data.atomic_numbers.long() + data.natoms = data.num_nodes + + data.tags = 2 * torch.ones(data.natoms) + data.tags = data.tags.long() + + data.fixed = torch.zeros(data.natoms, dtype=torch.bool) + + data.pos = data.pos.float() + + cutoff = 19 + if data.natoms > 300: + max_neighbors = 5 + elif data.natoms > 200: + max_neighbors = 10 + else: + max_neighbors = 30 + + data = self.generate_graphs( + data, + cutoffs=Cutoffs.from_constant(cutoff), + max_neighbors=MaxNeighbors.from_goc_base_proportions(max_neighbors), + # cutoffs=Cutoffs.from_constant(12.0), + # max_neighbors=MaxNeighbors.from_goc_base_proportions(30), + pbc=True, + ) + + return data diff --git a/src/jmp/tasks/finetune/rmd17.py b/src/jmp/tasks/finetune/rmd17.py new file mode 100644 index 0000000..bea25a7 --- /dev/null +++ b/src/jmp/tasks/finetune/rmd17.py @@ -0,0 +1,97 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from typing import Literal, TypeAlias + +import torch +from torch_geometric.data.data import BaseData +from typing_extensions import override + +from ...utils.goc_graph import Cutoffs, Graph, MaxNeighbors +from .energy_forces_base import EnergyForcesConfigBase, EnergyForcesModelBase + +RMD17Molecule: TypeAlias = Literal[ + "aspirin", + "azobenzene", + "benzene", + "ethanol", + "malonaldehyde", + "naphthalene", + "paracetamol", + "salicylic", + "toluene", + "uracil", +] + + +class RMD17Config(EnergyForcesConfigBase): + molecule: RMD17Molecule + + graph_scalar_loss_coefficients: dict[str, float] = {"y": 0.0} + node_vector_loss_coefficients: dict[str, float] = {"force": 10.0} + + cutoff: float = 7.0 + max_neighbors: int = 100 + + +class RMD17Model(EnergyForcesModelBase[RMD17Config]): + @classmethod + @override + def config_cls(cls): + return RMD17Config + + @override + def metric_prefix(self) -> str: + return f"md17/{self.config.molecule}" + + @override + def generate_graphs_transform(self, data: BaseData): + return self.generate_graphs( + data, + cutoffs=Cutoffs.from_constant(self.config.cutoff), + max_neighbors=MaxNeighbors.from_goc_base_proportions( + self.config.max_neighbors + ), + pbc=False, + ) + + @override + def process_aint_graph(self, aint_graph: Graph): + return aint_graph + + @override + def data_transform(self, data: BaseData): + data = super().data_transform(data) + + data = copy.deepcopy(data) + + if not torch.is_tensor(data.y): + data.y = torch.tensor(data.y, dtype=torch.float) + data.y = data.y.view(-1).float() + if hasattr(data, "z"): + data.atomic_numbers = data.pop("z") + data.atomic_numbers = data.atomic_numbers.long() + data.natoms = data.num_nodes + + data.tags = 2 * torch.ones(data.natoms) + data.tags = data.tags.long() + + data.fixed = torch.zeros(data.natoms, dtype=torch.bool) + + # data.cell = (torch.eye(3) * 1000.0).unsqueeze(dim=0) + # data.cell = torch.tensor( + # [ + # [ + # [8, 0.0000, -0.0000], + # [-0.0000, 12.7363, -0.0000], + # [0.0000, 0.0000, 47.3956], + # ] + # ] + # ) + return data diff --git a/src/jmp/tasks/finetune/spice.py b/src/jmp/tasks/finetune/spice.py new file mode 100644 index 0000000..82b6b92 --- /dev/null +++ b/src/jmp/tasks/finetune/spice.py @@ -0,0 +1,69 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Literal, TypeAlias, final + +import torch +from torch_geometric.data.data import BaseData +from typing_extensions import override + +from ...utils.goc_graph import Cutoffs, Graph, MaxNeighbors +from .energy_forces_base import EnergyForcesConfigBase, EnergyForcesModelBase + +SPICEDataset: TypeAlias = Literal["solvated_amino_acids", "dipeptides"] + + +class SPICEConfig(EnergyForcesConfigBase): + dataset: SPICEDataset + + graph_scalar_targets: list[str] = ["y"] + node_vector_targets: list[str] = ["force"] + + graph_scalar_loss_coefficients: dict[str, float] = {"y": 1.0} + node_vector_loss_coefficients: dict[str, float] = {"force": 100.0} + + +@final +class SPICEModel(EnergyForcesModelBase[SPICEConfig]): + @classmethod + @override + def config_cls(cls): + return SPICEConfig + + @override + def metric_prefix(self) -> str: + return f"spice/{self.config.dataset}" + + @override + def process_aint_graph(self, aint_graph: Graph): + return aint_graph + + @override + def generate_graphs_transform(self, data: BaseData): + return self.generate_graphs( + data, + cutoffs=Cutoffs.from_constant(12.0), + max_neighbors=MaxNeighbors.from_goc_base_proportions(30), + pbc=False, + ) + + @override + def data_transform(self, data: BaseData): + data = super().data_transform(data) + + data.y = data.pop("formation_energy").view(-1).float() + data.atomic_numbers = data.pop("atomic_numbers").long() + data.natoms = data.num_nodes + + data.tags = 2 * torch.ones(data.natoms) + data.tags = data.tags.long() + + data.fixed = torch.zeros(data.natoms, dtype=torch.bool) + + data.cell = (torch.eye(3) * 1000.0).unsqueeze(dim=0) + return data diff --git a/src/jmp/tasks/pretrain/__init__.py b/src/jmp/tasks/pretrain/__init__.py new file mode 100644 index 0000000..cbf801b --- /dev/null +++ b/src/jmp/tasks/pretrain/__init__.py @@ -0,0 +1,14 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .module import PretrainConfig, PretrainModel + +__all__ = [ + "PretrainConfig", + "PretrainModel", +] diff --git a/src/jmp/tasks/pretrain/module.py b/src/jmp/tasks/pretrain/module.py new file mode 100644 index 0000000..5e08228 --- /dev/null +++ b/src/jmp/tasks/pretrain/module.py @@ -0,0 +1,980 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import math +from collections.abc import Callable +from functools import cache, partial +from logging import getLogger +from typing import Annotated, Generic, Literal, TypeAlias, assert_never, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import pack, rearrange, reduce +from jmp.lightning import Base, BaseConfig, Field, LightningModuleBase, TypedConfig +from jmp.lightning.data.balanced_batch_sampler import BalancedBatchSampler +from jmp.lightning.util.typed import TypedModuleDict, TypedModuleList +from lightning.pytorch.utilities.types import ( + LRSchedulerConfigType, + OptimizerLRSchedulerConfig, +) +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from torch_geometric.data.batch import Batch +from torch_geometric.data.data import BaseData +from torch_geometric.utils import dropout_edge +from torch_scatter import scatter +from torchmetrics import SumMetric +from typing_extensions import TypeVar, override + +from ...datasets.pretrain_lmdb import PretrainDatasetConfig as PretrainDatasetConfigBase +from ...datasets.pretrain_lmdb import PretrainLmdbDataset +from ...models.gemnet.backbone import GemNetOCBackbone, GOCBackboneOutput +from ...models.gemnet.config import BackboneConfig +from ...models.gemnet.layers.base_layers import ScaledSiLU +from ...modules import transforms as T +from ...modules.dataset import dataset_transform as DT +from ...modules.dataset.common import CommonDatasetConfig, wrap_common_dataset +from ...modules.dataset.concat_dataset import MTDatasetConfig, MTSampledDataset +from ...modules.ema import EMAConfig +from ...modules.metrics import FMMetrics +from ...modules.scheduler.linear_warmup_cosine_annealing import ( + LinearWarmupCosineAnnealingLR, +) +from ...modules.transforms.normalize import NormalizationConfig +from ...utils.goc_graph import ( + Cutoffs, + Graph, + MaxNeighbors, + generate_graph, + subselect_graph, + tag_mask, +) +from ..config import ( + EmbeddingConfig, + OptimizerConfig, + OutputConfig, + optimizer_from_config, +) + +log = getLogger(__name__) + + +class LinearWarmupCosineAnnealingSchedulerConfig(TypedConfig): + name: Literal["linear_warmup_cosine_annealing"] = "linear_warmup_cosine_annealing" + + warmup_steps: int = 0 + max_steps: int | None = None + max_epochs: int | None = None + warmup_start_lr_factor: float = 0.0 + min_lr_factor: float = 1.0e-2 + last_step: int = -1 + + +LRSchedulerConfig: TypeAlias = Annotated[ + LinearWarmupCosineAnnealingSchedulerConfig, Field(discriminator="name") +] + + +class PretrainDatasetConfig(PretrainDatasetConfigBase, CommonDatasetConfig): + pass + + +class TaskConfig(TypedConfig): + name: str + """Name of the task.""" + + train_dataset: PretrainDatasetConfig + """Train dataset configuration.""" + + val_dataset: PretrainDatasetConfig + """Validation dataset configuration.""" + + node_energy_reduction: Literal["sum", "mean"] = "sum" + """How to reduce the node energy scalar contributions (to get the total energy).""" + + additional_units: list[str] = [] + """Additional units to log for this task.""" + + energy_loss_scale: float = 1.0 + """Scale factor for the energy loss.""" + force_loss_scale: float = 1.0 + """Scale factor for the force loss.""" + + normalization: dict[str, NormalizationConfig] | None = None + """ + Normalization to apply to the target values. + Each key is the name of the target value + and the value is a dict with the mean and std. + """ + + +class PretrainConfig(BaseConfig): + optimizer: OptimizerConfig + """Optimizer to use.""" + lr_scheduler: LRSchedulerConfig | None = None + """Learning rate scheduler configuration. If None, no learning rate scheduler is used.""" + + activation: Literal[ + "scaled_silu", + "scaled_swish", + "silu", + "swish", + ] = "scaled_silu" + """Activation function to use.""" + + dropout: float | None = None + """The dropout rate to use in GemNet.""" + edge_dropout: float | None = None + """The percentage of edges to drop. If None, no edges are dropped.""" + + embedding: EmbeddingConfig = EmbeddingConfig( + num_elements=BackboneConfig.base().num_elements, + embedding_size=BackboneConfig.base().emb_size_atom, + ) + """Configuration for the embedding layer.""" + backbone: BackboneConfig = BackboneConfig.base() + """Configuration for the backbone.""" + output: OutputConfig = OutputConfig(num_mlps=5) + """Configuration for the output head.""" + + batch_size: int + """Batch size to use.""" + eval_batch_size: int | None = None + """Batch size to use for evaluation. If None, use the same as batch_size.""" + num_workers: int + """Number of workers to use for data loading.""" + pin_memory: bool = True + """Whether to use pin memory for data loading.""" + + shuffle_train: bool = True + """Should we shuffle the training dataset?""" + + shuffle_val: bool = False + """Should we shuffle the validation dataset?""" + + @property + def activation_cls(self): + match self.activation: + case "scaled_silu" | "scaled_swish": + return ScaledSiLU + case "silu" | "swish": + return nn.SiLU + case None: + return nn.Identity + case _: + raise NotImplementedError( + f"Activation {self.activation} is not implemented" + ) + + log_task_losses: bool = True + """Log the loss for each task.""" + log_task_steps_and_epochs: bool = True + """Log the number of steps and epochs for each task.""" + + tasks: list[TaskConfig] + """List of datasets/tasks to train on.""" + mt_dataset: MTDatasetConfig = MTDatasetConfig( + balanced=True, + strict=True, + ) + """Configuration for the multi-task dataset.""" + + exclude_keys: list[str] = [ + "id", # only oc20,oc22 have this + "fid", # only oc20,oc22 have this + "cell_offsets", # only oc20 has this + "edge_index", # only oc20 has this + "absolute_idx", # only ani has this + "target_pos", # only ani has this + "ref_energy", # only ani/geom have this + "pbc", # only ani/transition1x have this + "oc22", # only oc22 has this + "name", + ] + """Keys to exclude when creating a batch from a data list.""" + + train_on_free_atoms_only: bool = False + """Train only on free atoms.""" + + eval_on_free_atoms_only: bool = True + """Evaluate only on free atoms.""" + + energy_loss_reduction: Literal["sum", "mean"] = "mean" + """How to reduce the energy loss. "sum" or "mean".""" + force_loss_reduction: Literal["sum", "mean"] = "mean" + """How to reduce the force loss. "sum" or "mean".""" + + structurewise_loss_reduction: bool = True + """Use the proposed structurewise loss (from the paper) reduction for the force loss.""" + + ema: EMAConfig | None = None + """Configuration for the exponential moving average.""" + + @override + def __post_init__(self): + super().__post_init__() + + self.trainer.use_distributed_sampler = False + + self.backbone.dropout = self.dropout + self.backbone.edge_dropout = self.edge_dropout + + +class Embedding(Base[PretrainConfig], nn.Module): + @override + def __init__(self, hparams: PretrainConfig): + super().__init__(hparams) + + self.atom_embedding = nn.Embedding( + num_embeddings=self.config.embedding.num_elements, + embedding_dim=self.config.embedding.embedding_size, + ) + + @override + def forward(self, data: BaseData): + atomic_numbers = data.atomic_numbers - 1 + x = self.atom_embedding(atomic_numbers) + return x + + +class Output(Base[PretrainConfig], nn.Module): + @override + def __init__(self, hparams: PretrainConfig): + super().__init__(hparams) + + def dims( + emb_size: int, + *, + num_targets: int = self.config.backbone.num_targets, + num_mlps: int = self.config.output.num_mlps, + ): + return ([emb_size] * num_mlps) + [num_targets] + + self.out_energy = TypedModuleList( + [ + self.mlp( + dims(self.config.backbone.emb_size_atom), + activation=self.config.activation_cls, + ) + for _ in self.config.tasks + ] + ) + self.out_forces = TypedModuleList( + [ + self.mlp( + dims(self.config.backbone.emb_size_edge), + activation=self.config.activation_cls, + ) + for _ in self.config.tasks + ] + ) + + @override + def forward(self, data: BaseData, backbone_out: GOCBackboneOutput): + energy = backbone_out["energy"] + forces = backbone_out["forces"] + V_st = backbone_out["V_st"] + idx_t = backbone_out["idx_t"] + + batch: torch.Tensor = data.batch + n_molecules = int(torch.max(batch).item() + 1) + n_atoms = data.atomic_numbers.shape[0] + + energy_list: list[torch.Tensor] = [] + forces_list: list[torch.Tensor] = [] + + for energy_mlp, forces_mlp, task in zip( + self.out_energy, self.out_forces, self.config.tasks + ): + E_t = energy_mlp(energy) # (n_atoms, 1) + E_t = scatter( + E_t, + batch, + dim=0, + dim_size=n_molecules, + reduce=task.node_energy_reduction, + ) + energy_list.append(E_t) # (bsz, 1) + + F_st = forces_mlp(forces) # (n_edges, 1) + F_st = F_st * V_st # (n_edges, 3) + F_t = scatter(F_st, idx_t, dim=0, dim_size=n_atoms, reduce="sum") + forces_list.append(F_t) # (n_atoms, 3) + + E, _ = pack(energy_list, "bsz *") + F, _ = pack(forces_list, "n_atoms p *") + + return E, F + + +TConfig = TypeVar( + "TConfig", bound=PretrainConfig, default=PretrainConfig, infer_variance=True +) + + +class PretrainModel(LightningModuleBase[TConfig], Generic[TConfig]): + @classmethod + @override + def config_cls(cls): + return PretrainConfig + + @staticmethod + def _model_validate_config(config: TConfig): + assert ( + config.activation.lower() == config.backbone.activation.lower() + ), f"{config.activation=} != {config.backbone.activation=}" + + assert ( + config.embedding.num_elements == config.backbone.num_elements + ), f"{config.embedding.num_elements=} != {config.backbone.num_elements=}" + assert ( + config.embedding.embedding_size == config.backbone.emb_size_atom + ), f"{config.embedding.embedding_size=} != {config.backbone.emb_size_atom=}" + + def _construct_backbone(self): + backbone = GemNetOCBackbone(self.config.backbone, **dict(self.config.backbone)) + return backbone + + @override + def __init__(self, hparams: TConfig): + self._model_validate_config(hparams) + super().__init__(hparams) + + # Set up callbacks + if (ema := self.config.ema) is not None: + self.register_callback(lambda: ema.construct_callback()) + + # Set up the model + self.embedding = Embedding(self.config) + self.backbone = self._construct_backbone() + self.output = Output(self.config) + + # Set up the metrics + self.train_metrics = FMMetrics( + { + task.name: {"idx": idx, "additional_units": task.additional_units} + for idx, task in enumerate(self.config.tasks) + }, + denormalize=any(task.normalization for task in self.config.tasks), + free_atoms_only=self.config.eval_on_free_atoms_only, + ) + self.val_metrics = FMMetrics( + { + task.name: {"idx": idx, "additional_units": task.additional_units} + for idx, task in enumerate(self.config.tasks) + }, + denormalize=any(task.normalization for task in self.config.tasks), + free_atoms_only=self.config.eval_on_free_atoms_only, + ) + + # GemNet-OC re-uses some parameters at every layer. + # We need to make sure that these parameters' gradients are + # downscaled by the number of layers so that the gradients + # are not too large. + if self.backbone.shared_parameters: + self.register_shared_parameters(self.backbone.shared_parameters) + + self._train_dataset_sizes: list[int] | None = None + if self.config.log_task_steps_and_epochs: + task_steps: dict[str, SumMetric] = {} + for task in self.config.tasks: + metric = SumMetric() + metric.persistent(True) + task_steps[task.name] = metric + self.task_steps = TypedModuleDict(task_steps) + + def backbone_state_dict(self): + return { + "backbone": self.backbone.state_dict(), + "embedding": self.embedding.atom_embedding.state_dict(), + } + + @override + def on_train_batch_start(self, batch: BaseData, batch_idx: int): + if not self.config.log_task_steps_and_epochs: + return + + assert self._train_dataset_sizes + + task_mask = batch.task_mask # (b, t) + task_idx = reduce(task_mask, "b t -> t", "sum") # (t,) + for idx, task in enumerate(self.config.tasks): + metric = self.task_steps[task.name] + metric(task_idx[idx]) + + step = metric.compute() + self.log(f"train/{task.name}/step", step) + + epoch = step / self._train_dataset_sizes[idx] + self.log(f"train/{task.name}/epoch", epoch) + + @override + def forward(self, batch: BaseData): + h = self.embedding(batch) + out: GOCBackboneOutput = self.backbone(batch, h=h) + return self.output(batch, out) # (n h), (n p h) + + def _task_idx_onehot(self, task_idx: int): + return F.one_hot( + torch.tensor([task_idx], device=self.device, dtype=torch.long), + num_classes=len(self.config.tasks), + ).bool() + + def _force_loss( + self, batch: BaseData, forces: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.debug: + assert forces.shape == batch.force.shape + + pred: torch.Tensor = rearrange(forces, "n p t -> n t p") + target: torch.Tensor = rearrange(batch.force, "n p t -> n t p") + + mask = batch.task_mask # b t + mask = mask[batch.batch] # n t + if self.config.train_on_free_atoms_only: + mask = mask & rearrange(~batch.fixed, "n -> n 1") + + force_loss = F.pairwise_distance(pred, target, p=2.0) # (n, t) + + if (scale := getattr(batch, "force_scale", None)) is not None: + # force_loss_scale: (b,) + scale = scale[batch.batch] # (n, t) + if self.config.train_on_free_atoms_only: + scale = scale[~batch.fixed] + force_loss = force_loss * scale + + if (scale := getattr(batch, "force_scale_node", None)) is not None: + # force_scale_node: (n, t) + if self.config.train_on_free_atoms_only: + scale = scale[~batch.fixed] + force_loss = force_loss * scale + + force_loss = force_loss.masked_fill(~mask, 0.0) + + if self.config.log_task_losses: + with torch.no_grad(): + for task_idx, task in enumerate(self.config.tasks): + task_mask = mask & self._task_idx_onehot(task_idx) + task_force_loss = force_loss.masked_fill(~task_mask, 0.0) + self.log( + f"{task.name}/force_loss", + self._reduce_loss( + task_force_loss, + task_mask, + reduction=self.config.force_loss_reduction, + ), + ) + + # force_loss = self._reduce_force_loss(force_loss, mask) + return force_loss, mask + + def _energy_loss( + self, batch: BaseData, energy: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + mask = batch.task_mask # (b, h) + + energy_loss = F.l1_loss(energy, batch.y, reduction="none") # (b, num_tasks) + + if (scale := getattr(batch, "y_scale", None)) is not None: + energy_loss = energy_loss * scale # (b, t) + + energy_loss = energy_loss.masked_fill(~mask, 0.0) + + if self.config.log_task_losses: + with torch.no_grad(): + for task_idx, task in enumerate(self.config.tasks): + task_mask = mask & self._task_idx_onehot(task_idx) + task_energy_loss = energy_loss.masked_fill(~task_mask, 0.0) + self.log( + f"{task.name}/energy_loss", + self._reduce_loss( + task_energy_loss, + task_mask, + reduction=self.config.energy_loss_reduction, + ), + ) + + return energy_loss, mask + + @staticmethod + def _safe_divide(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + b = b.masked_fill(b == 0.0, 1.0) + return a / b + + def _reduce_loss( + self, + loss: torch.Tensor, + mask: torch.Tensor, + reduction: Literal["sum", "mean"], + ): + match reduction: + case "sum": + loss = reduce(loss, "b t -> ", "sum") + case "mean": + # loss = reduce(loss, "b t -> ", "sum") / reduce(mask, "b t -> ", "sum") + loss = self._safe_divide( + reduce(loss, "b t -> ", "sum"), + reduce(mask, "b t -> ", "sum"), + ) + case _: + raise ValueError(f"Unknown redution: {reduction}") + + return loss + + def compute_losses( + self, batch: BaseData, energy: torch.Tensor, forces: torch.Tensor + ): + # Compute the energy loss + energy_loss, energy_loss_mask = self._energy_loss( + batch, energy + ) # (b, t), (b, t) + energy_loss = self._reduce_loss( + energy_loss, energy_loss_mask, reduction=self.config.energy_loss_reduction + ) + self.log("energy_loss", energy_loss) + + # Compute the force loss + force_loss, force_loss_mask = self._force_loss(batch, forces) + if self.config.structurewise_loss_reduction: + # Compute the per-structure force loss + force_loss = scatter(force_loss, batch.batch, dim=0, reduce="sum") # (b, t) + force_loss_mask_natoms = scatter( + force_loss_mask.float(), batch.batch, dim=0, reduce="sum" + ) # (b, t) + force_loss = self._safe_divide(force_loss, force_loss_mask_natoms) # (b, t) + force_loss_mask = force_loss_mask_natoms > 0.0 # (b, t) + force_loss = self._reduce_loss( + force_loss, force_loss_mask, reduction=self.config.force_loss_reduction + ) + self.log("force_loss", force_loss) + + # Combine the losses + loss = energy_loss + force_loss + self.log("loss", loss) + + return loss + + @override + def training_step(self, batch: BaseData, batch_idx: int): + with self.log_context(prefix="train/"): + energy, forces = self(batch) + + loss = self.compute_losses(batch, energy=energy, forces=forces) + self.log_dict(self.train_metrics(batch, energy=energy, forces=forces)) + + return loss + + @override + def validation_step(self, batch: BaseData, batch_idx: int): + with self.log_context(prefix="val/"): + energy, forces = self(batch) + + metrics = self.val_metrics(batch, energy=energy, forces=forces) + self.log_dict(metrics) + + def configure_lr_scheduler( + self, optimizer: torch.optim.Optimizer + ) -> LRSchedulerConfigType | None: + match self.config.lr_scheduler: + case None: + return None + case LinearWarmupCosineAnnealingSchedulerConfig() as config: + if not (max_steps := config.max_steps): + if max_epochs := config.max_epochs: + _ = self.trainer.estimated_stepping_batches # make sure dataloaders are loaded for self.trainer.num_training_batches + num_steps_per_epoch = math.ceil( + self.trainer.num_training_batches + / self.trainer.accumulate_grad_batches + ) + max_steps = max_epochs * num_steps_per_epoch + else: + max_steps = self.trainer.estimated_stepping_batches + assert math.isfinite(max_steps), f"{max_steps=} is not finite" + max_steps = int(max_steps) + + log.critical(f"Setting {max_steps=} by default.") + + optim_lr = float(optimizer.param_groups[0]["lr"]) + min_lr = optim_lr * config.min_lr_factor + warmup_start_lr = optim_lr * config.warmup_start_lr_factor + lr_scheduler = LinearWarmupCosineAnnealingLR( + optimizer, + warmup_epochs=config.warmup_steps, + max_epochs=max_steps, + warmup_start_lr=warmup_start_lr, + eta_min=min_lr, + last_epoch=config.last_step, + ) + return { + "scheduler": lr_scheduler, + "interval": "step", + "frequency": 1, + "reduce_on_plateau": False, + "strict": True, # type: ignore + } + + case _: + assert_never(self.config.lr_scheduler) + + @override + def configure_optimizers(self): + optimizer = optimizer_from_config([(self.config.optimizer, self.parameters())]) + out: OptimizerLRSchedulerConfig = {"optimizer": optimizer} + if (lr_scheduler := self.configure_lr_scheduler(optimizer)) is not None: + out["lr_scheduler"] = lr_scheduler + + return out + + def _task_dataset(self, task: TaskConfig, training: bool): + config = task.val_dataset if not training else task.train_dataset + dataset = PretrainLmdbDataset(config) + dataset = wrap_common_dataset(dataset, config) + + # Apply data transform to the dataset + if (transform := getattr(self, f"{task.name}_transform")) is None: + raise ValueError(f"Transform not defined for {task.name}") + transform = cast( + Callable[[BaseData], BaseData], partial(transform, training=training) + ) + + # Apply normalization to the dataset + if task.normalization: + log.info(f"Normalizing {task.name} with {task.normalization}") + transform = T.compose([transform, T.normalize(task.normalization)]) + + dataset = DT.transform(dataset, transform) + + return dataset + + def _construct_fm_datasets(self, training: bool): + datasets = [] + for task in self.config.tasks: + datasets.append(self._task_dataset(task, training=training)) + return datasets + + @cache + def train_dataset(self): + datasets = self._construct_fm_datasets(training=True) + self._train_dataset_sizes = [len(d) for d in datasets] + # if self.config.log_task_steps_and_epochs: + dataset = MTSampledDataset( + datasets, + self.config.mt_dataset, + ignore_balancing=False, + num_tasks=len(self.config.tasks), + ) + dataset = DT.transform(dataset, self.train_data_transform) + return dataset + + def representative_batch_for_testing(self, *, n: int, start_index: int = 0): + dataset = self.train_dataset() + data_list = dataset.representative_batch_for_testing( + n=n, start_index=start_index + ) + data_list = [self.train_data_transform(data) for data in data_list] + return data_list + + @cache + def val_dataset(self): + datasets = self._construct_fm_datasets(training=False) + dataset = MTSampledDataset( + datasets, + self.config.mt_dataset, + ignore_balancing=True, + num_tasks=len(self.config.tasks), + ) + dataset = DT.transform(dataset, self.val_data_transform) + return dataset + + def collate_fn(self, data_list: list[BaseData]): + return Batch.from_data_list(data_list, exclude_keys=self.config.exclude_keys) + + def distributed_sampler(self, dataset: Dataset, shuffle: bool): + return DistributedSampler( + dataset, + num_replicas=self.trainer.world_size, + rank=self.trainer.global_rank, + shuffle=shuffle, + ) + + @override + def train_dataloader(self): + dataset = self.train_dataset() + sampler = self.distributed_sampler(dataset, shuffle=self.config.shuffle_train) + batch_sampler = BalancedBatchSampler( + sampler, + batch_size=self.config.batch_size, + device=self.device, + ) + data_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn, + num_workers=self.config.num_workers, + pin_memory=self.config.pin_memory, + ) + return data_loader + + @override + def val_dataloader(self): + dataset = self.val_dataset() + sampler = self.distributed_sampler(dataset, shuffle=self.config.shuffle_val) + batch_sampler = BalancedBatchSampler( + sampler, + batch_size=self.config.batch_size, + device=self.device, + ) + data_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn, + num_workers=self.config.num_workers, + pin_memory=self.config.pin_memory, + ) + return data_loader + + def _task_config(self, name: str): + return next((task for task in self.config.tasks if task.name == name), None) + + @staticmethod + def _to_int(value): + return int(value.item() if torch.is_tensor(value) else value) + + def train_data_transform(self, data: BaseData): + data = self.data_transform(data) + return data + + def val_data_transform(self, data: BaseData): + data = self.data_transform(data) + return data + + def data_transform(self, data: BaseData): + data.y = ( + data.y.float() + if torch.is_tensor(data.y) + else torch.tensor(data.y, dtype=torch.float) + ) + + data.fixed = data.fixed.bool() + data.atomic_numbers = data.atomic_numbers.long() + data.natoms = self._to_int(data.natoms) + data.sid = self._to_int(data.sid) + for graph_type in ["main", "a2a", "a2ee2a", "qint"]: + key = f"{graph_type}_num_neighbors" + setattr(data, key, self._to_int(data[key])) + + for attr in ("y", "force"): + key = f"{attr}_scale" + if not hasattr(data, key): + raise ValueError(f"{key=} not found in data") + + # make all tensors contiguous + for key in data.keys(): + if not torch.is_tensor(data[key]): + continue + + data[key] = data[key].contiguous() + + return data + + def _process_aint_graph(self, graph: Graph, *, training: bool): + if self.config.edge_dropout: + graph["edge_index"], mask = dropout_edge( + graph["edge_index"], + p=self.config.edge_dropout, + training=training, + ) + graph["distance"] = graph["distance"][mask] + graph["vector"] = graph["vector"][mask] + graph["cell_offset"] = graph["cell_offset"][mask] + + if "id_swap_edge_index" in graph: + graph["id_swap_edge_index"] = graph["id_swap_edge_index"][mask] + + return graph + + def _generate_graphs( + self, + data: BaseData, + cutoffs: Cutoffs, + max_neighbors: MaxNeighbors, + pbc: bool, + *, + training: bool, + ): + aint_graph = generate_graph( + data, cutoff=cutoffs.aint, max_neighbors=max_neighbors.aint, pbc=pbc + ) + aint_graph = self._process_aint_graph(aint_graph, training=training) + subselect = partial( + subselect_graph, + data, + aint_graph, + cutoff_orig=cutoffs.aint, + max_neighbors_orig=max_neighbors.aint, + ) + main_graph = subselect(cutoffs.main, max_neighbors.main) + aeaint_graph = subselect(cutoffs.aeaint, max_neighbors.aeaint) + qint_graph = subselect(cutoffs.qint, max_neighbors.qint) + + # We can't do this at the data level: This is because the batch collate_fn doesn't know + # that it needs to increment the "id_swap" indices as it collates the data. + # So we do this at the graph level (which is done in the GemNetOC `get_graphs_and_indices` method). + # main_graph = symmetrize_edges(main_graph, num_atoms=data.pos.shape[0]) + qint_graph = tag_mask(data, qint_graph, tags=self.config.backbone.qint_tags) + + graphs = { + "main": main_graph, + "a2a": aint_graph, + "a2ee2a": aeaint_graph, + "qint": qint_graph, + } + + for graph_type, graph in graphs.items(): + graph["num_neighbors"] = graph["edge_index"].shape[1] + for key, value in graph.items(): + setattr(data, f"{graph_type}_{key}", value) + + return data + + def _initial_data_transform(self, data: BaseData): + if not torch.is_tensor(data.y): + data.y = torch.tensor(data.y) + data.y = data.y.view(-1) + + return data + + def oc20_transform(self, data: BaseData, *, training: bool): + data = self._initial_data_transform(data) + + assert ( + config := self._task_config("oc20") + ) is not None, "OC20 task is not configured" + + # convert back these keys into required format for collation + data.natoms = int(data.natoms.item() if torch.is_tensor(data) else data.natoms) + + data.atomic_numbers = data.atomic_numbers.long() + data.tags = data.tags.long() + + data = self._generate_graphs( + data, + cutoffs=Cutoffs.from_constant(12.0), + max_neighbors=MaxNeighbors.from_goc_base_proportions(30), + pbc=True, + training=training, + ) + + data.y_scale = config.energy_loss_scale + data.force_scale = config.force_loss_scale + + return data + + def oc22_transform(self, data: BaseData, *, training: bool): + data = self._initial_data_transform(data) + + assert ( + config := self._task_config("oc22") + ) is not None, "OC22 task is not configured" + + # convert back these keys into required format for collation + data.natoms = int(data.natoms.item() if torch.is_tensor(data) else data.natoms) + + data.atomic_numbers = data.atomic_numbers.long() + data.tags = data.tags.long() + try: + data.y = torch.tensor(float(data.y)).view(-1) + except BaseException: + data.y = torch.tensor(float(data.y_relaxed)).view(-1) + data.name = "oc22" + + data = self._generate_graphs( + data, + cutoffs=Cutoffs.from_constant(12.0), + max_neighbors=MaxNeighbors.from_goc_base_proportions(30), + pbc=True, + training=training, + ) + + data.y_scale = config.energy_loss_scale + data.force_scale = config.force_loss_scale + + return data + + @staticmethod + def _set_inf_cell(data: BaseData, max_length: float = 1000.0): + data.cell = (torch.eye(3) * max_length).unsqueeze(dim=0) + return data + + def ani1x_transform(self, data: BaseData, *, training: bool): + data = self._initial_data_transform(data) + + assert ( + config := self._task_config("ani1x") + ) is not None, "ANI1x task is not configured" + + data.y = data.y.view(-1).float() + if not hasattr(data, "sid"): + data.sid = data.absolute_idx + if not hasattr(data, "natoms"): + data.natoms = data.num_nodes + + # data.fixed = torch.ones(data.natoms) + data.fixed = torch.zeros(data.natoms, dtype=torch.bool) + + data.tags = 2 * torch.ones(data.natoms) + data.tags = data.tags.long() + data.name = "ani1x" + + data = self._set_inf_cell(data) + data = self._generate_graphs( + data, + cutoffs=Cutoffs.from_constant(8.0), + max_neighbors=MaxNeighbors.from_goc_base_proportions(30), + pbc=False, + training=training, + ) + + data.y_scale = config.energy_loss_scale + data.force_scale = config.force_loss_scale + + return data + + def transition1x_transform(self, data: BaseData, *, training: bool): + data = self._initial_data_transform(data) + + assert ( + config := self._task_config("transition1x") + ) is not None, "Transition1x task is not configured" + + data.y = data.y.view(-1).float() + if not hasattr(data, "sid"): + data.sid = data.absolute_idx + if not hasattr(data, "natoms"): + data.natoms = data.num_nodes + + # data.fixed = torch.ones(data.natoms) + data.fixed = torch.zeros(data.natoms, dtype=torch.bool) + + data.tags = 2 * torch.ones(data.natoms) + data.tags = data.tags.long() + data.name = "transition1x" + + data = self._set_inf_cell(data) + data = self._generate_graphs( + data, + cutoffs=Cutoffs.from_constant(8.0), + max_neighbors=MaxNeighbors.from_goc_base_proportions(30), + pbc=False, + training=training, + ) + + data.y_scale = config.energy_loss_scale + data.force_scale = config.force_loss_scale + + return data diff --git a/src/jmp/tasks/pretrain/normalization.py b/src/jmp/tasks/pretrain/normalization.py new file mode 100644 index 0000000..259788b --- /dev/null +++ b/src/jmp/tasks/pretrain/normalization.py @@ -0,0 +1,66 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + + +class Normalization: + full = { + "oc20": { + "y": {"mean": 1.305542661963295, "std": 24.901469505465872}, + "force": {"mean": 0.0, "std": 0.5111534595489502}, + }, + "oc22": { + "y": {"mean": 1.232371959806986, "std": 25.229595396538468}, + "force": {"mean": 0.0, "std": 0.25678861141204834}, + }, + "ani1x": { + "y": {"mean": 0.3835804075000375, "std": 2.8700712783472118}, + "force": {"mean": 0.0, "std": 2.131422996520996}, + }, + "transition1x": { + "y": {"mean": 0.03843723322120415, "std": 1.787466168382901}, + "force": {"mean": 0.0, "std": 0.3591422140598297}, + }, + } + + linref = { + "oc20": { + "y": {"mean": -0.7536948323249817, "std": 2.940723180770874}, + "force": {"mean": 0.0, "std": 0.43649423122406006}, + }, + "oc22": { + "y": {"mean": 1.2600674629211426, "std": 25.42051887512207}, + "force": {"mean": 0.0, "std": 0.2522418200969696}, + }, + "ani1x": { + "y": {"mean": 0.3596225082874298, "std": 2.8952934741973877}, + "force": {"mean": 0.0, "std": 2.1361355781555176}, + }, + "transition1x": { + "y": {"mean": -59.33904266357422, "std": 10.360939979553223}, + "force": {"mean": -2.941862021543784e-06, "std": 0.3591934144496918}, + }, + } + + nolinref = { + "oc20": { + "y": {"mean": -359.82421875, "std": 231.93690490722656}, + "force": {"mean": 0.0, "std": 0.43649423122406006}, + }, + "oc22": { + "y": {"mean": -495.55059814453125, "std": 212.4519805908203}, + "force": {"mean": 0.0, "std": 0.2522418200969696}, + }, + "ani1x": { + "y": {"mean": -10700.826171875, "std": 3739.096923828125}, + "force": {"mean": 0.0, "std": 2.1361355781555176}, + }, + "transition1x": { + "y": {"mean": -8254.189453125, "std": 1067.7886962890625}, + "force": {"mean": -2.941862021543784e-06, "std": 0.3591934144496918}, + }, + } diff --git a/src/jmp/utils/finetune_state_dict.py b/src/jmp/utils/finetune_state_dict.py new file mode 100644 index 0000000..94340cc --- /dev/null +++ b/src/jmp/utils/finetune_state_dict.py @@ -0,0 +1,131 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import fnmatch +from logging import getLogger +from pathlib import Path +from typing import Any, TypedDict, cast + +import torch +from typing_extensions import NotRequired + +log = getLogger(__name__) + + +class _LightningCheckpoint(TypedDict): + optimizer_states: NotRequired[list[Any]] + state_dict: dict[str, torch.Tensor] + + +def _get_parameter_list_from_state_dict( + state_dict: dict[str, torch.Tensor], + ignore_scale_factor: bool, + param_dict_prefixes: list[list[str]] | None = None, +): + if param_dict_prefixes: + buffer = list(state_dict.keys()) + ordered_parameters: list[str] = [] + for prefixes in param_dict_prefixes: + to_remove = set[str]() + for key in buffer: + if not any(fnmatch.fnmatch(key, prefix) for prefix in prefixes): + continue + + ordered_parameters.append(key) + to_remove.add(key) # we don't want to remove while iterating + + log.critical(f"Processed {len(to_remove)} keys under prefixes: {prefixes}") + for key in to_remove: + buffer.remove(key) + + if buffer: + ordered_parameters.extend(buffer) + else: + ordered_parameters = list(state_dict.keys()) + + parameters_list: list[str] = [] + for k in ordered_parameters: + if k.endswith("rbf.offset") or k.endswith("rbf.temps"): + continue + if ".seq_energy_pre." in k: + continue + if k.startswith("task_steps"): + continue + if ignore_scale_factor and k.endswith("scale_factor"): + continue + parameters_list.append(k) + return parameters_list + + +def retreive_state_dict_for_finetuning( + ckpt_path: str | Path, + load_emas: bool = True, + ignore_scale_factor: bool = True, + param_dict_prefixes: list[list[str]] | None = None, +): + ckpt_path = Path(ckpt_path) + if not ckpt_path.exists(): + raise FileNotFoundError(ckpt_path) + + ckpt = torch.load(ckpt_path, map_location="cpu") + assert isinstance(ckpt, dict), type(ckpt) + + ckpt = cast(_LightningCheckpoint, ckpt) + state_dict = retreive_ft_state_dict_from_loaded_ckpt( + ckpt, load_emas, ignore_scale_factor, param_dict_prefixes + ) + log.critical(f"Loaded state dict from {ckpt_path}") + return state_dict + + +def retreive_ft_state_dict_from_loaded_ckpt( + ckpt: _LightningCheckpoint, + load_emas: bool = True, + ignore_scale_factor: bool = True, + param_dict_prefixes: list[list[str]] | None = None, +): + state_dict = ckpt["state_dict"].copy() + if load_emas: + optimizer_states = ckpt.get("optimizer_states") + assert optimizer_states is not None, "optimizer_states must be present" + + assert len(optimizer_states) == 1, f"{len(optimizer_states)=} != 1" + optimizer_state = optimizer_states[0] + assert isinstance(optimizer_state, dict), type(optimizer_state) + assert (ema := optimizer_state.get("ema")) is not None, "ema must be present" + assert isinstance(ema, tuple), type(ema) + + ema = cast(list[torch.Tensor], list(ema)) + parameters_list = _get_parameter_list_from_state_dict( + state_dict, + ignore_scale_factor, + param_dict_prefixes, + ) + + assert len(ema) == len( + parameters_list + ), f"{len(ema)=} != {len(parameters_list)=}" + + for i, (param, ema) in enumerate(zip(parameters_list, ema)): + existing_param = state_dict[param] + assert ( + existing_param.shape == ema.shape + ), f"{existing_param.shape=} != {ema.shape=} for {param=} at index {i}" + assert ( + existing_param.dtype == ema.dtype + ), f"{existing_param.dtype=} != {ema.dtype=} for {param=} at index {i}" + state_dict[param] = ema + log.info(f"Loaded EMA for {param}") + + log.critical(f"Loaded {len(parameters_list)} EMA parameters") + + return state_dict + + +def filter_state_dict(state_dict: dict[str, torch.Tensor], prefix: str): + return {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)} diff --git a/src/jmp/utils/fit_scales.py b/src/jmp/utils/fit_scales.py new file mode 100644 index 0000000..a137b90 --- /dev/null +++ b/src/jmp/utils/fit_scales.py @@ -0,0 +1,136 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from collections.abc import Callable +from logging import getLogger +from pathlib import Path +from typing import Literal + +import torch +from jmp.lightning import BaseConfig as Config +from jmp.lightning import LightningModuleBase, Trainer +from jmp.models.gemnet.backbone import GemNetOCBackbone +from jmp.modules.scaling import ScaleFactor +from typing_extensions import TypeVar + +log = getLogger(__name__) + +TModel = TypeVar("TModel", bound=LightningModuleBase, infer_variance=True) + + +def fit_scales( + config: Config, + model_cls: Callable[[Config], TModel], + out_path: Path, + *, + backbone: Callable[[TModel], GemNetOCBackbone], + num_batches: int = 16, + fitted_mode: Literal["replace", "ignore", "error"] = "replace", + replace_out_path_if_exists: bool = False, +): + config = copy.deepcopy(config) + + # config.trainer.fast_dev_run = num_batches + config.trainer.precision = "32-true" + config.trainer.logger = False + config.trainer.max_steps = num_batches + config.trainer.limit_val_batches = num_batches + config.trainer.num_sanity_val_steps = 0 + + with Trainer.context(config): + model = model_cls(config) + + scale_factors = { + name: module + for name, module in backbone(model).named_modules() + if isinstance(module, ScaleFactor) + } + + # region detect fitted/unfitted factors + fitted_scale_factors = [ + f"{name}: {module.scale_factor.item():.3f}" + for name, module in scale_factors.items() + if module.fitted + ] + unfitted_scale_factors = [ + name for name, module in scale_factors.items() if not module.fitted + ] + fitted_scale_factors_str = ", ".join(fitted_scale_factors) + log.info(f"Fitted scale factors: [{fitted_scale_factors_str}]") + unfitted_scale_factors_str = ", ".join(unfitted_scale_factors) + log.info(f"Unfitted scale factors: [{unfitted_scale_factors_str}]") + + if fitted_scale_factors: + match fitted_mode: + case "replace": + log.info("Replacing fitted scale factors with new ones.") + case "ignore": + log.info("Ignoring fitted scale factors.") + case "error": + log.error("Found fitted scale factors.") + log.error("Exiting script.") + return + case _: + raise ValueError(f"Unknown fitted_mode: {fitted_mode}") + # endregion + + log.info( + f"Output path for fitted scale factors: {out_path}, {out_path.exists()=}" + ) + if out_path.exists() and not replace_out_path_if_exists: + raise FileExistsError(f"Output path already exists: {out_path}") + + # region reset the scale factors if mode == "all" + if fitted_mode == "replace": + log.info("Fitting all scale factors and resetting the fitted ones.") + for name, scale_factor in scale_factors.items(): + if scale_factor.fitted: + log.info( + f"{name} is already fitted in the checkpoint, resetting it. {scale_factor.scale_factor}" + ) + scale_factor.reset_() + # endregion + + # loop over the scale factors in the computation order + # and fit them one by one + log.info("Start fitting") + trainer = Trainer(config) + + for name, module in scale_factors.items(): + try: + if module.fitted and fitted_mode == "ignore": + log.info(f"Skipping {name} (already fitted)") + continue + + log.info(f"Fitting {name}...") + with module.fit_context_(): + _ = trainer.validate(model, verbose=False) + stats, ratio, value = module.fit_() + + log.info( + f"Variable: {name}, " + f"Var_in: {stats['variance_in']:.3f}, " + f"Var_out: {stats['variance_out']:.3f}, " + f"Ratio: {ratio:.3f} => Scaling factor: {value:.3f}" + ) + except BaseException as e: + raise RuntimeError(f"Failed to fit {name}") from e + + # make sure all scale factors are fitted + for name, module in scale_factors.items(): + if not module.fitted: + raise RuntimeError(f"Scale factor {name} is not fitted.") + + # region save the scale factors to the checkpoint file + scale_factors_out = { + name: module.scale_factor.clone() for name, module in scale_factors.items() + } + torch.save(scale_factors_out, out_path) + log.info(f"Saved results to: {out_path}") + # endregion diff --git a/src/jmp/utils/goc_graph.py b/src/jmp/utils/goc_graph.py new file mode 100644 index 0000000..2241290 --- /dev/null +++ b/src/jmp/utils/goc_graph.py @@ -0,0 +1,557 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections.abc import Callable +from dataclasses import dataclass +from functools import wraps +from typing import ParamSpec, TypedDict, cast + +import numpy as np +import torch +from torch_geometric.data.batch import Batch +from torch_geometric.data.data import BaseData +from torch_geometric.nn import radius_graph +from torch_geometric.utils import sort_edge_index +from torch_scatter import segment_coo +from typing_extensions import NotRequired + +from ..models.gemnet.utils import ( + get_edge_id, + get_max_neighbors_mask, + mask_neighbors, + repeat_blocks, +) +from .ocp import get_pbc_distances +from .radius_graph import radius_graph_pbc + + +class Graph(TypedDict): + edge_index: torch.Tensor # 2 e + distance: torch.Tensor # e + vector: torch.Tensor # e 3 + cell_offset: torch.Tensor # e 3 + num_neighbors: torch.Tensor # b + + cutoff: torch.Tensor # b + max_neighbors: torch.Tensor # b + + id_swap_edge_index: NotRequired[torch.Tensor] # e + + +@dataclass(frozen=True, kw_only=True) +class Cutoffs: + main: float + aeaint: float + qint: float + aint: float + + @classmethod + def from_constant(cls, value: float): + return cls(main=value, aeaint=value, qint=value, aint=value) + + +@dataclass(frozen=True, kw_only=True) +class MaxNeighbors: + main: int + aeaint: int + qint: int + aint: int + + @classmethod + def from_goc_base_proportions(cls, max_neighbors: int): + """ + GOC base proportions: + max_neighbors: 30 + max_neighbors_qint: 8 + max_neighbors_aeaint: 20 + max_neighbors_aint: 1000 + """ + return cls( + main=max_neighbors, + aeaint=int(max_neighbors * 20 / 30), + qint=int(max_neighbors * 8 / 30), + aint=int(max_neighbors * 1000 / 30), + ) + + +def _select_symmetric_edges(tensor, mask, reorder_idx, opposite_neg): + """Use a mask to remove values of removed edges and then + duplicate the values for the correct edge direction. + + Arguments + --------- + tensor: torch.Tensor + Values to symmetrize for the new tensor. + mask: torch.Tensor + Mask defining which edges go in the correct direction. + reorder_idx: torch.Tensor + Indices defining how to reorder the tensor values after + concatenating the edge values of both directions. + opposite_neg: bool + Whether the edge in the opposite direction should use the + negative tensor value. + + Returns + ------- + tensor_ordered: torch.Tensor + A tensor with symmetrized values. + """ + # Mask out counter-edges + tensor_directed = tensor[mask] + # Concatenate counter-edges after normal edges + sign = 1 - 2 * opposite_neg + tensor_cat = torch.cat([tensor_directed, sign * tensor_directed]) + # Reorder everything so the edges of every image are consecutive + tensor_ordered = tensor_cat[reorder_idx] + return tensor_ordered + + +def symmetrize_edges(graph: Graph, num_atoms: int): + """ + Symmetrize edges to ensure existence of counter-directional edges. + + Some edges are only present in one direction in the data, + since every atom has a maximum number of neighbors. + We only use i->j edges here. So we lose some j->i edges + and add others by making it symmetric. + """ + new_graph = graph.copy() + + # Generate mask + mask_sep_atoms = graph["edge_index"][0] < graph["edge_index"][1] + # Distinguish edges between the same (periodic) atom by ordering the cells + cell_earlier = ( + (graph["cell_offset"][:, 0] < 0) + | ((graph["cell_offset"][:, 0] == 0) & (graph["cell_offset"][:, 1] < 0)) + | ( + (graph["cell_offset"][:, 0] == 0) + & (graph["cell_offset"][:, 1] == 0) + & (graph["cell_offset"][:, 2] < 0) + ) + ) + mask_same_atoms = graph["edge_index"][0] == graph["edge_index"][1] + mask_same_atoms &= cell_earlier + mask = mask_sep_atoms | mask_same_atoms + + # Mask out counter-edges + edge_index_directed = graph["edge_index"][mask[None, :].expand(2, -1)].view(2, -1) + + # Concatenate counter-edges after normal edges + edge_index_cat = torch.cat( + [edge_index_directed, edge_index_directed.flip(0)], + dim=1, + ) + + # Count remaining edges per image + batch_edge = torch.repeat_interleave( + torch.arange( + graph["num_neighbors"].size(0), + device=graph["edge_index"].device, + ), + graph["num_neighbors"], + ) + batch_edge = batch_edge[mask] + # segment_coo assumes sorted batch_edge + # Factor 2 since this is only one half of the edges + ones = batch_edge.new_ones(1).expand_as(batch_edge) + new_graph["num_neighbors"] = 2 * segment_coo( + ones, batch_edge, dim_size=graph["num_neighbors"].size(0) + ) + + # Create indexing array + edge_reorder_idx = repeat_blocks( + torch.div(new_graph["num_neighbors"], 2, rounding_mode="floor"), + repeats=2, + continuous_indexing=True, + repeat_inc=edge_index_directed.size(1), + ) + + # Reorder everything so the edges of every image are consecutive + new_graph["edge_index"] = edge_index_cat[:, edge_reorder_idx] + new_graph["cell_offset"] = _select_symmetric_edges( + graph["cell_offset"], mask, edge_reorder_idx, True + ) + new_graph["distance"] = _select_symmetric_edges( + graph["distance"], mask, edge_reorder_idx, False + ) + new_graph["vector"] = _select_symmetric_edges( + graph["vector"], mask, edge_reorder_idx, True + ) + + # Indices for swapping c->a and a->c (for symmetric MP) + # To obtain these efficiently and without any index assumptions, + # we get order the counter-edge IDs and then + # map this order back to the edge IDs. + # Double argsort gives the desired mapping + # from the ordered tensor to the original tensor. + edge_ids = get_edge_id(new_graph["edge_index"], new_graph["cell_offset"], num_atoms) + order_edge_ids = torch.argsort(edge_ids) + inv_order_edge_ids = torch.argsort(order_edge_ids) + edge_ids_counter = get_edge_id( + new_graph["edge_index"].flip(0), + -new_graph["cell_offset"], + num_atoms, + ) + order_edge_ids_counter = torch.argsort(edge_ids_counter) + id_swap_edge_index = order_edge_ids_counter[inv_order_edge_ids] + + new_graph["id_swap_edge_index"] = id_swap_edge_index + + return cast(Graph, new_graph) + + +def tag_mask(data: BaseData, graph: Graph, *, tags: list[int]): + tags_ = torch.tensor(tags, dtype=torch.long, device=data.tags.device) + + # Only use quadruplets for certain tags + tags_s = data.tags[graph["edge_index"][0]] + tags_t = data.tags[graph["edge_index"][1]] + tag_mask_s = (tags_s[..., None] == tags_).any(dim=-1) + tag_mask_t = (tags_t[..., None] == tags_).any(dim=-1) + tag_mask = tag_mask_s | tag_mask_t + + graph["edge_index"] = graph["edge_index"][:, tag_mask] + graph["cell_offset"] = graph["cell_offset"][tag_mask, :] + graph["distance"] = graph["distance"][tag_mask] + graph["vector"] = graph["vector"][tag_mask, :] + + return graph + + +def _generate_graph( + data: BaseData, + *, + cutoff: float, + max_neighbors: int, + pbc: bool, +): + if pbc: + edge_index, cell_offsets, neighbors = radius_graph_pbc( + data, cutoff, max_neighbors + ) + + out = get_pbc_distances( + data.pos, + edge_index, + data.cell, + cell_offsets, + neighbors, + return_offsets=True, + return_distance_vec=True, + ) + + edge_index: torch.Tensor = out["edge_index"] + edge_dist: torch.Tensor = out["distances"] + cell_offset_distances: torch.Tensor = out["offsets"] + distance_vec: torch.Tensor = out["distance_vec"] + else: + edge_index = radius_graph( + data.pos, + r=cutoff, + batch=data.batch, + max_num_neighbors=max_neighbors, + ) + + j, i = edge_index + distance_vec = data.pos[j] - data.pos[i] + + edge_dist = distance_vec.norm(dim=-1) + cell_offsets = torch.zeros(edge_index.shape[1], 3, device=data.pos.device) + cell_offset_distances = torch.zeros_like(cell_offsets, device=data.pos.device) + neighbors = edge_index.shape[1] + + return ( + edge_index, + edge_dist, + distance_vec, + cell_offsets, + cell_offset_distances, + neighbors, + ) + + +def generate_graph( + data: BaseData, + *, + cutoff: float, + max_neighbors: int, + pbc: bool, + symmetrize: bool = False, + filter_tags: list[int] | None = None, + sort_edges: bool = False, +): + ( + edge_index, + edge_dist, + distance_vec, + cell_offsets, + _, # cell offset distances + num_neighbors, + ) = _generate_graph( + data, + cutoff=cutoff, + max_neighbors=max_neighbors, + pbc=pbc, + ) + # These vectors actually point in the opposite direction. + # But we want to use col as idx_t for efficient aggregation. + edge_vector = -distance_vec / edge_dist[:, None] + # cell_offsets = -cell_offsets # a - c + offset + + graph: Graph = { + "edge_index": edge_index, + "distance": edge_dist, + "vector": edge_vector, + "cell_offset": cell_offsets, + "num_neighbors": num_neighbors, + "cutoff": torch.tensor(cutoff, dtype=torch.float, device=data.pos.device), + "max_neighbors": torch.tensor( + max_neighbors, dtype=torch.long, device=data.pos.device + ), + } + + if symmetrize: + graph = symmetrize_edges(graph, data.pos.shape[0]) + + if filter_tags is not None: + graph = tag_mask(data, graph, tags=filter_tags) + + if sort_edges: + ( + graph["edge_index"], + [ + graph["distance"], + graph["vector"], + graph["cell_offset"], + ], + ) = sort_edge_index( + graph["edge_index"], + [ + graph["distance"], + graph["vector"], + graph["cell_offset"], + ], + num_nodes=data.pos.shape[0], + sort_by_row=False, + ) + + graph["num_neighbors"] = torch.full_like( + graph["num_neighbors"], graph["edge_index"].shape[1] + ) + + return graph + + +def _subselect_edges( + data: BaseData, + graph: Graph, + cutoff: float | None = None, + max_neighbors: int | None = None, +): + """Subselect edges using a stricter cutoff and max_neighbors.""" + subgraph = graph.copy() + + if cutoff is not None: + edge_mask = subgraph["distance"] <= cutoff + + subgraph["edge_index"] = subgraph["edge_index"][:, edge_mask] + subgraph["cell_offset"] = subgraph["cell_offset"][edge_mask] + subgraph["num_neighbors"] = mask_neighbors(subgraph["num_neighbors"], edge_mask) + subgraph["distance"] = subgraph["distance"][edge_mask] + subgraph["vector"] = subgraph["vector"][edge_mask] + + if max_neighbors is not None: + subgraph["max_neighbors"] = torch.tensor( + max_neighbors, dtype=torch.long, device=data.pos.device + ) + edge_mask, subgraph["num_neighbors"] = get_max_neighbors_mask( + natoms=torch.tensor([data.natoms], dtype=torch.long, device=data.pos.device) + if not torch.is_tensor(data.natoms) + else data.natoms.view(-1), + index=subgraph["edge_index"][1], + atom_distance=subgraph["distance"], + max_num_neighbors_threshold=max_neighbors, + ) + if not torch.all(edge_mask): + subgraph["edge_index"] = subgraph["edge_index"][:, edge_mask] + subgraph["cell_offset"] = subgraph["cell_offset"][edge_mask] + subgraph["distance"] = subgraph["distance"][edge_mask] + subgraph["vector"] = subgraph["vector"][edge_mask] + + empty_image = subgraph["num_neighbors"] == 0 + if torch.any(empty_image): + raise ValueError(f"An image has no neighbors: {data}") + return subgraph + + +def subselect_graph( + data: BaseData, + graph: Graph, + cutoff: float, + max_neighbors: int, + cutoff_orig: float, + max_neighbors_orig: int, +): + """If the new cutoff and max_neighbors is different from the original, + subselect the edges of a given graph. + """ + # Check if embedding edges are different from interaction edges + if np.isclose(cutoff, cutoff_orig): + select_cutoff = None + else: + select_cutoff = cutoff + if max_neighbors == max_neighbors_orig: + select_neighbors = None + else: + select_neighbors = max_neighbors + + graph = _subselect_edges( + data=data, + graph=graph, + cutoff=select_cutoff, + max_neighbors=select_neighbors, + ) + return graph + + +def generate_graphs( + data: BaseData, + *, + cutoffs: Cutoffs | Callable[[BaseData], Cutoffs], + max_neighbors: MaxNeighbors | Callable[[BaseData], MaxNeighbors], + pbc: bool, + symmetrize_main: bool = False, + qint_tags: list[int] | None = [1, 2], +): + if callable(cutoffs): + cutoffs = cutoffs(data) + if callable(max_neighbors): + max_neighbors = max_neighbors(data) + + assert cutoffs.main <= cutoffs.aint + assert cutoffs.aeaint <= cutoffs.aint + assert cutoffs.qint <= cutoffs.aint + + assert max_neighbors.main <= max_neighbors.aint + assert max_neighbors.aeaint <= max_neighbors.aint + assert max_neighbors.qint <= max_neighbors.aint + + main_graph = generate_graph( + data, + cutoff=cutoffs.main, + max_neighbors=max_neighbors.main, + pbc=pbc, + symmetrize=symmetrize_main, + ) + a2a_graph = generate_graph( + data, + cutoff=cutoffs.aint, + max_neighbors=max_neighbors.aint, + pbc=pbc, + ) + a2ee2a_graph = generate_graph( + data, + cutoff=cutoffs.aeaint, + max_neighbors=max_neighbors.aeaint, + pbc=pbc, + ) + qint_graph = generate_graph( + data, + cutoff=cutoffs.qint, + max_neighbors=max_neighbors.qint, + pbc=pbc, + filter_tags=qint_tags, + ) + + graphs = { + "main": main_graph, + "a2a": a2a_graph, + "a2ee2a": a2ee2a_graph, + "qint": qint_graph, + } + return graphs + + +P = ParamSpec("P") + + +def with_goc_graphs( + cutoffs: Cutoffs | Callable[[BaseData], Cutoffs], + max_neighbors: MaxNeighbors | Callable[[BaseData], MaxNeighbors], + pbc: bool, + symmetrize_main: bool = False, + qint_tags: list[int] | None = [1, 2], +): + def decorator(func: Callable[P, BaseData]): + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> BaseData: + data = func(*args, **kwargs) + + graphs = generate_graphs( + data, + cutoffs=cutoffs, + max_neighbors=max_neighbors, + pbc=pbc, + symmetrize_main=symmetrize_main, + qint_tags=qint_tags, + ) + for graph_type, graph in graphs.items(): + for key, value in graph.items(): + setattr(data, f"{graph_type}_{key}", value) + + return data + + return wrapper + + return decorator + + +class Graphs(TypedDict): + main: Graph + a2a: Graph + a2ee2a: Graph + qint: Graph + + +GRAPH_TYPES = ["main", "a2a", "a2ee2a", "qint"] + + +def graphs_from_batch(data: BaseData | Batch) -> Graphs: + global GRAPH_TYPES + + graphs = { + graph_type: { + "edge_index": getattr(data, f"{graph_type}_edge_index"), + "distance": getattr(data, f"{graph_type}_distance"), + "vector": getattr(data, f"{graph_type}_vector"), + "cell_offset": getattr(data, f"{graph_type}_cell_offset"), + "num_neighbors": getattr(data, f"{graph_type}_num_neighbors", None), + "cutoff": getattr(data, f"{graph_type}_cutoff", None), + "max_neighbors": getattr(data, f"{graph_type}_max_neighbors", None), + "id_swap_edge_index": getattr( + data, f"{graph_type}_id_swap_edge_index", None + ), + } + for graph_type in GRAPH_TYPES + } + # remove None values + graphs = { + graph_type: {key: value for key, value in graph.items() if value is not None} + for graph_type, graph in graphs.items() + } + return cast(Graphs, graphs) + + +def graphs_to_batch(data: BaseData | Batch, graphs: Graphs): + global GRAPH_TYPES + + for graph_type in GRAPH_TYPES: + for key, value in graphs[graph_type].items(): + setattr(data, f"{graph_type}_{key}", value) + + return data diff --git a/src/jmp/utils/ocp.py b/src/jmp/utils/ocp.py new file mode 100644 index 0000000..220a874 --- /dev/null +++ b/src/jmp/utils/ocp.py @@ -0,0 +1,66 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import cast + +import torch +import torch_geometric +from torch_geometric.data.data import BaseData, Data + + +def pyg2_data_transform(data: BaseData): + """ + if we're on the new pyg (2.0 or later) and if the Data stored is in older format + we need to convert the data to the new format + """ + if torch_geometric.__version__ >= "2.0" and "_store" not in data.__dict__: + data = Data(**{k: v for k, v in data.__dict__.items() if v is not None}) + data = cast(BaseData, data) + + return data + + +def get_pbc_distances( + pos, + edge_index, + cell, + cell_offsets, + neighbors, + return_offsets=False, + return_distance_vec=False, +): + row, col = edge_index + + distance_vectors = pos[row] - pos[col] + + # correct for pbc + neighbors = neighbors.to(cell.device) + cell = torch.repeat_interleave(cell, neighbors, dim=0) + offsets = cell_offsets.float().view(-1, 1, 3).bmm(cell.float()).view(-1, 3) + distance_vectors += offsets + + # compute distances + distances = distance_vectors.norm(dim=-1) + + # redundancy: remove zero distances + nonzero_idx = torch.arange(len(distances), device=distances.device)[distances != 0] + edge_index = edge_index[:, nonzero_idx] + distances = distances[nonzero_idx] + + out = { + "edge_index": edge_index, + "distances": distances, + } + + if return_distance_vec: + out["distance_vec"] = distance_vectors[nonzero_idx] + + if return_offsets: + out["offsets"] = offsets[nonzero_idx] + + return out diff --git a/src/jmp/utils/param_specific_util.py b/src/jmp/utils/param_specific_util.py new file mode 100644 index 0000000..b26b9b4 --- /dev/null +++ b/src/jmp/utils/param_specific_util.py @@ -0,0 +1,88 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from typing import Any, cast + +from jmp.tasks.finetune.base import ( + FinetuneConfigBase, + ParamSpecificOptimizerConfig, + WarmupCosRLPConfig, +) +from typing_extensions import TypeVar + + +def PARAMETER_PATTERNS(num_blocks: int): + return { + "embedding": ["embedding.*"], + "additional_embedding": ["additional_embedding.*"], + "bases": ["backbone.bases.*"], + # "all_int_blocks": ["backbone.int_blocks.*"], + **{ + f"int_blocks_{i}": [f"backbone.int_blocks.{i}.*"] for i in range(num_blocks) + }, + # "all_out_blocks": ["backbone.out_blocks.*"], + **{ + f"out_blocks_{i}": [f"backbone.out_blocks.{i}.*"] + for i in range(num_blocks + 1) + }, + **{ + f"blocks_{i}": [ + f"backbone.int_blocks.{i}.*", + f"backbone.out_blocks.{i+1}.*", + *(["backbone.out_blocks.0.*"] if i == 0 else []), + ] + for i in range(num_blocks) + }, + "out_mlp_E": ["backbone.out_mlp.E.*"], + } + + +TConfig = TypeVar("TConfig", infer_variance=True) + + +def make_parameter_specific_optimizer_config( + config: FinetuneConfigBase, + num_blocks: int, + max_lr_scales: dict[str, float], +): + base_lr = config.optimizer.lr + + parameter_specific_optimizers: list[ParamSpecificOptimizerConfig] = [] + max_lr_scales = cast(dict[str, Any], max_lr_scales) + for name, lr_scale in max_lr_scales.items(): + assert isinstance(lr_scale, float), f"max_lr_scales[{name}] must be float" + + optimizer = copy.deepcopy(config.optimizer) + optimizer.lr = base_lr * lr_scale + + lrs = None + match config.lr_scheduler: + case WarmupCosRLPConfig(): + lrs = copy.deepcopy(config.lr_scheduler) + # We now scale down the cos annealing min LR factor + # so that the final LR is the same as the original config. + lrs.min_lr_factor = lrs.min_lr_factor / lr_scale + lrs.min_lr_factor = max(0.01, min(0.99, lrs.min_lr_factor)) + case _: + raise ValueError( + "You must set config.lr_scheduler to WarmupCosRLPConfig to use parameter specific optimizers." + ) + + assert ( + (parameter_patterns := PARAMETER_PATTERNS(num_blocks).get(name)) is not None + ), f"PARAMETER_PATTERNS[{name}] is None. You must set PARAMETER_PATTERNS[{name}]" + parameter_specific_optimizers.append( + ParamSpecificOptimizerConfig( + paremeter_patterns=parameter_patterns, + optimizer=optimizer, + lr_scheduler=lrs, + ) + ) + + return parameter_specific_optimizers diff --git a/src/jmp/utils/radius_graph.py b/src/jmp/utils/radius_graph.py new file mode 100644 index 0000000..da80b71 --- /dev/null +++ b/src/jmp/utils/radius_graph.py @@ -0,0 +1,253 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import torch +from torch_geometric.data.batch import Batch +from torch_scatter import segment_coo, segment_csr + + +def radius_graph_pbc( + data: Batch, + radius: float, + max_num_neighbors_threshold: int, + pbc: list[bool] = [True, True, True], +): + device = data.pos.device + + if isinstance(data.natoms, int): + data.natoms = torch.tensor([data.natoms], device=device) + + batch_size = len(data.natoms) + + if hasattr(data, "pbc"): + data.pbc = torch.atleast_2d(data.pbc) + for i in range(3): + if not torch.any(data.pbc[:, i]).item(): + pbc[i] = False + elif torch.all(data.pbc[:, i]).item(): + pbc[i] = True + else: + raise RuntimeError( + "Different structures in the batch have different PBC configurations. This is not currently supported." + ) + + # position of the atoms + atom_pos = data.pos + + # Before computing the pairwise distances between atoms, first create a list of atom indices to compare for the entire batch + num_atoms_per_image = data.natoms + num_atoms_per_image_sqr = (num_atoms_per_image**2).long() + + # index offset between images + index_offset = torch.cumsum(num_atoms_per_image, dim=0) - num_atoms_per_image + + index_offset_expand = torch.repeat_interleave(index_offset, num_atoms_per_image_sqr) + num_atoms_per_image_expand = torch.repeat_interleave( + num_atoms_per_image, num_atoms_per_image_sqr + ) + + # Compute a tensor containing sequences of numbers that range from 0 to num_atoms_per_image_sqr for each image + # that is used to compute indices for the pairs of atoms. This is a very convoluted way to implement + # the following (but 10x faster since it removes the for loop) + # for batch_idx in range(batch_size): + # batch_count = torch.cat([batch_count, torch.arange(num_atoms_per_image_sqr[batch_idx], device=device)], dim=0) + num_atom_pairs = torch.sum(num_atoms_per_image_sqr) + index_sqr_offset = ( + torch.cumsum(num_atoms_per_image_sqr, dim=0) - num_atoms_per_image_sqr + ) + index_sqr_offset = torch.repeat_interleave( + index_sqr_offset, num_atoms_per_image_sqr + ) + atom_count_sqr = torch.arange(num_atom_pairs, device=device) - index_sqr_offset + + # Compute the indices for the pairs of atoms (using division and mod) + # If the systems get too large this apporach could run into numerical precision issues + index1 = ( + torch.div(atom_count_sqr, num_atoms_per_image_expand, rounding_mode="floor") + ) + index_offset_expand + index2 = (atom_count_sqr % num_atoms_per_image_expand) + index_offset_expand + # Get the positions for each atom + pos1 = torch.index_select(atom_pos, 0, index1) + pos2 = torch.index_select(atom_pos, 0, index2) + + # Calculate required number of unit cells in each direction. + # Smallest distance between planes separated by a1 is + # 1 / ||(a2 x a3) / V||_2, since a2 x a3 is the area of the plane. + # Note that the unit cell volume V = a1 * (a2 x a3) and that + # (a2 x a3) / V is also the reciprocal primitive vector + # (crystallographer's definition). + + cross_a2a3 = torch.cross(data.cell[:, 1], data.cell[:, 2], dim=-1) + cell_vol = torch.sum(data.cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) + + if pbc[0]: + inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) + rep_a1 = torch.ceil(radius * inv_min_dist_a1) + else: + rep_a1 = data.cell.new_zeros(1) + + if pbc[1]: + cross_a3a1 = torch.cross(data.cell[:, 2], data.cell[:, 0], dim=-1) + inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) + rep_a2 = torch.ceil(radius * inv_min_dist_a2) + else: + rep_a2 = data.cell.new_zeros(1) + + if pbc[2]: + cross_a1a2 = torch.cross(data.cell[:, 0], data.cell[:, 1], dim=-1) + inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) + rep_a3 = torch.ceil(radius * inv_min_dist_a3) + else: + rep_a3 = data.cell.new_zeros(1) + + # Take the max over all images for uniformity. This is essentially padding. + # Note that this can significantly increase the number of computed distances + # if the required repetitions are very different between images + # (which they usually are). Changing this to sparse (scatter) operations + # might be worth the effort if this function becomes a bottleneck. + max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()] + + # Tensor of unit cells + cells_per_dim = [ + torch.arange(-rep, rep + 1, device=device, dtype=torch.float) for rep in max_rep + ] + unit_cell = torch.cartesian_prod(*cells_per_dim) + num_cells = len(unit_cell) + unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat(len(index2), 1, 1) + unit_cell = torch.transpose(unit_cell, 0, 1) + unit_cell_batch = unit_cell.view(1, 3, num_cells).expand(batch_size, -1, -1) + + # Compute the x, y, z positional offsets for each cell in each image + data_cell = torch.transpose(data.cell, 1, 2) + pbc_offsets = torch.bmm(data_cell, unit_cell_batch) + pbc_offsets_per_atom = torch.repeat_interleave( + pbc_offsets, num_atoms_per_image_sqr, dim=0 + ) + + # Expand the positions and indices for the 9 cells + pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells) + pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells) + index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1) + index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1) + # Add the PBC offsets for the second atom + pos2 = pos2 + pbc_offsets_per_atom + + # Compute the squared distance between atoms + atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1) + atom_distance_sqr = atom_distance_sqr.view(-1) + + # Remove pairs that are too far apart + mask_within_radius = torch.le(atom_distance_sqr, radius * radius) + # Remove pairs with the same atoms (distance = 0.0) + mask_not_same = torch.gt(atom_distance_sqr, 0.0001) + mask = torch.logical_and(mask_within_radius, mask_not_same) + index1 = torch.masked_select(index1, mask) + index2 = torch.masked_select(index2, mask) + unit_cell = torch.masked_select( + unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask) + + mask_num_neighbors, num_neighbors_image = get_max_neighbors_mask( + natoms=data.natoms, + index=index1, + atom_distance=atom_distance_sqr, + max_num_neighbors_threshold=max_num_neighbors_threshold, + ) + + if not torch.all(mask_num_neighbors): + # Mask out the atoms to ensure each atom has at most max_num_neighbors_threshold neighbors + index1 = torch.masked_select(index1, mask_num_neighbors) + index2 = torch.masked_select(index2, mask_num_neighbors) + unit_cell = torch.masked_select( + unit_cell.view(-1, 3), mask_num_neighbors.view(-1, 1).expand(-1, 3) + ) + unit_cell = unit_cell.view(-1, 3) + + edge_index = torch.stack((index2, index1)) + + return edge_index, unit_cell, num_neighbors_image + + +def get_max_neighbors_mask( + natoms: torch.Tensor, + index: torch.Tensor, + atom_distance: torch.Tensor, + max_num_neighbors_threshold: int, +): + """ + Give a mask that filters out edges so that each atom has at most + `max_num_neighbors_threshold` neighbors. + Assumes that `index` is sorted. + """ + device = natoms.device + num_atoms = natoms.sum() + + # Get number of neighbors + # segment_coo assumes sorted index + ones = index.new_ones(1).expand_as(index) + num_neighbors = segment_coo(ones, index, dim_size=num_atoms) + max_num_neighbors = num_neighbors.max() + num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) + + # Get number of (thresholded) neighbors per image + image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) + image_indptr[1:] = torch.cumsum(natoms, dim=0) + num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + + # If max_num_neighbors is below the threshold, return early + if ( + max_num_neighbors <= max_num_neighbors_threshold + or max_num_neighbors_threshold <= 0 + ): + mask_num_neighbors = torch.tensor([True], dtype=bool, device=device).expand_as( + index + ) + return mask_num_neighbors, num_neighbors_image + + # Create a tensor of size [num_atoms, max_num_neighbors] to sort the distances of the neighbors. + # Fill with infinity so we can easily remove unused distances later. + distance_sort = torch.full([num_atoms * max_num_neighbors], np.inf, device=device) + + # Create an index map to map distances from atom_distance to distance_sort + # index_sort_map assumes index to be sorted + index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors + index_neighbor_offset_expand = torch.repeat_interleave( + index_neighbor_offset, num_neighbors + ) + index_sort_map = ( + index * max_num_neighbors + + torch.arange(len(index), device=device) + - index_neighbor_offset_expand + ) + distance_sort.index_copy_(0, index_sort_map, atom_distance) + distance_sort = distance_sort.view(num_atoms, max_num_neighbors) + + # Sort neighboring atoms based on distance + distance_sort, index_sort = torch.sort(distance_sort, dim=1) + # Select the max_num_neighbors_threshold neighbors that are closest + distance_sort = distance_sort[:, :max_num_neighbors_threshold] + index_sort = index_sort[:, :max_num_neighbors_threshold] + + # Offset index_sort so that it indexes into index + index_sort = index_sort + index_neighbor_offset.view(-1, 1).expand( + -1, max_num_neighbors_threshold + ) + # Remove "unused pairs" with infinite distances + mask_finite = torch.isfinite(distance_sort) + index_sort = torch.masked_select(index_sort, mask_finite) + + # At this point index_sort contains the index into index of the + # closest max_num_neighbors_threshold neighbors per atom + # Create a mask to remove all pairs not in index_sort + mask_num_neighbors = torch.zeros(len(index), device=device, dtype=bool) + mask_num_neighbors.index_fill_(0, index_sort, True) + + return mask_num_neighbors, num_neighbors_image diff --git a/src/jmp/utils/state_dict.py b/src/jmp/utils/state_dict.py new file mode 100644 index 0000000..fd0c8f0 --- /dev/null +++ b/src/jmp/utils/state_dict.py @@ -0,0 +1,103 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import fnmatch +from collections import defaultdict +from collections.abc import Mapping +from logging import getLogger + +import torch +import torch.nn as nn +from lightning_fabric.utilities import rank_zero_warn + +log = getLogger(__name__) + + +def _report_incompat_keys( + model: nn.Module, + missing_keys: list[str], + unexpected_keys: list[str], + strict: bool = True, +): + error_msgs = [] + if missing_keys: + error_msgs.insert( + 0, + "Missing key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in missing_keys) + ), + ) + if unexpected_keys: + error_msgs.insert( + 0, + "Unexpected key(s) in state_dict: {}. ".format( + ", ".join('"{}"'.format(k) for k in unexpected_keys) + ), + ) + + if len(error_msgs) > 0: + error_msg = "Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(error_msgs) + ) + if strict: + raise RuntimeError(error_msg) + else: + rank_zero_warn(error_msg) + + +def load_state_dict( + module: nn.Module, + state_dict: Mapping[str, torch.Tensor], + ignored_key_patterns: list[str] | None = None, + ignored_missing_keys: list[str] | None = None, + ignored_unexpected_keys: list[str] | None = None, + strict: bool = True, +): + if ignored_key_patterns: + updated_state_dict: dict[str, torch.Tensor] = {} + matching_patterns = defaultdict[str, list[str]](lambda: []) + for k, v in state_dict.items(): + if ( + matched_pattern := next( + ( + pattern + for pattern in ignored_key_patterns + if fnmatch.fnmatch(k, pattern) + ), + None, + ) + ) is not None: + matching_patterns[matched_pattern].append(k) + continue + + updated_state_dict[k] = v + state_dict = updated_state_dict + + for pattern, matched_keys in matching_patterns.items(): + log.critical( + f"{pattern=} matched keys {matched_keys}, " + "which were ignored during loading." + ) + + missing_keys, unexpected_keys = module.load_state_dict(state_dict, strict=False) + + if ignored_key_patterns: + missing_keys = [ + k + for k in missing_keys + if not any(fnmatch.fnmatch(k, pattern) for pattern in ignored_key_patterns) + ] + if ignored_missing_keys: + missing_keys = [k for k in missing_keys if k not in ignored_missing_keys] + + if ignored_unexpected_keys: + unexpected_keys = [ + k for k in unexpected_keys if k not in ignored_unexpected_keys + ] + + _report_incompat_keys(module, missing_keys, unexpected_keys, strict=strict) diff --git a/src/submitit/__init__.py b/src/submitit/__init__.py new file mode 100644 index 0000000..8c2388c --- /dev/null +++ b/src/submitit/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" "Python 3.6+ toolbox for submitting jobs to Slurm""" + +# allow explicit reimports (mypy) by renaming all imports +from . import helpers as helpers +from .auto.auto import AutoExecutor as AutoExecutor +from .core.core import Executor as Executor +from .core.core import Job as Job +from .core.job_environment import JobEnvironment as JobEnvironment +from .local.debug import DebugExecutor as DebugExecutor +from .local.debug import DebugJob as DebugJob +from .local.local import LocalExecutor as LocalExecutor +from .local.local import LocalJob as LocalJob +from .slurm.slurm import SlurmExecutor as SlurmExecutor +from .slurm.slurm import SlurmJob as SlurmJob + +__version__ = "1.4.5" diff --git a/src/submitit/auto/__init__.py b/src/submitit/auto/__init__.py new file mode 100644 index 0000000..602d268 --- /dev/null +++ b/src/submitit/auto/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# diff --git a/src/submitit/auto/auto.py b/src/submitit/auto/auto.py new file mode 100644 index 0000000..8296af7 --- /dev/null +++ b/src/submitit/auto/auto.py @@ -0,0 +1,244 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import typing as tp +import warnings +from pathlib import Path +from typing import Any, List, Optional, Type, Union + +from ..core import plugins +from ..core.core import Executor, Job +from ..core.utils import DelayedSubmission + + +def _convert_deprecated_args( + kwargs: tp.Dict[str, Any], deprecated_args: tp.Mapping[str, str] +) -> None: + for arg in list(kwargs): + new_arg = deprecated_args.get(arg) + if not new_arg: + continue + kwargs[new_arg] = kwargs.pop(arg) + warnings.warn(f"Setting '{arg}' is deprecated. Use '{new_arg}' instead.") + + +class AutoExecutor(Executor): + """Automatic job executor + This class is used to hold the parameters to run a job either on the cluster + corresponding to the environment. + It can also be used to run job locally or in debug mode. + In practice, it will create a bash file in the specified directory for each job, + and pickle the task function and parameters. At completion, the job will also pickle + the output. Logs are also dumped in the same directory. + + Executor specific parameters must be specified by prefixing them with the name + of the executor they refer to. eg: + - 'chronos_conda_file' (internal) + - 'slurm_max_num_timeout' + See each executor documentation for the list of available parameters. + + Parameters + ---------- + folder: Path/str + folder for storing job submission/output and logs. + warn_ignored: bool + prints a warning each time a parameter is provided but ignored because it is only + useful for the other cluster. + cluster: str + Forces AutoExecutor to use the given environment. Use "local" to run jobs locally, + "debug" to run jobs in process. + kwargs: other arguments must be prefixed by the name of the executor they refer to. + {exname}_{argname}: see {argname} documentation in {Exname}Executor documentation. + + Note + ---- + - be aware that the log/output folder will be full of logs and pickled objects very fast, + it may need cleaning. + - use update_parameters to specify custom parameters (gpus_per_node etc...). If you + input erroneous parameters, an error will print all parameters available for you. + """ + + _ctor_deprecated_args = { + "max_num_timeout": "slurm_max_num_timeout", + "conda_file": "chronos_conda_file", + } + + def __init__( + self, folder: Union[str, Path], cluster: Optional[str] = None, **kwargs: Any + ) -> None: + self.cluster = cluster or self.which() + + executors = plugins.get_executors() + if self.cluster not in executors: + raise ValueError( + f"AutoExecutor doesn't know any executor named {self.cluster}" + ) + + _convert_deprecated_args(kwargs, self._ctor_deprecated_args) + err = "Extra arguments must be prefixed by executor named, received unknown arg" + err_ex_list = f"Known executors: {', '.join(executors)}." + for name in kwargs: + assert "_" in name, f"{err} '{name}'. {err_ex_list}" + prefix = name.split("_")[0] + assert ( + prefix in executors + ), f"{err} '{name}', and '{prefix}' executor is also unknown. {err_ex_list}" + self._executor = flexible_init(executors[self.cluster], folder, **kwargs) + + valid = self._valid_parameters() + self._deprecated_args = { + arg: f"{ex_name}_{arg}" + for ex_name, ex in executors.items() + for arg in ex._valid_parameters() + if arg not in valid + } + super().__init__(self._executor.folder, self._executor.parameters) + + @staticmethod + def which() -> str: + """Returns what is the detected cluster.""" + executors = plugins.get_executors() + best_ex = max(executors, key=lambda ex: executors[ex].affinity()) + + if executors[best_ex].affinity() <= 0: + raise RuntimeError( + f"Did not found an available executor among {executors.keys()}." + ) + + return best_ex + + def register_dev_folders(self, folders: List[Union[str, Path]]) -> None: + """Archive a list of folders to be untarred in the job working directory. + This is only implemented for internal cluster, for running job on non-installed packages. + This is not useful on slurm since the working directory of jobs is identical to + your work station working directory. + + folders: list of paths + The list of folders to archive and untar in the job working directory + """ + register = getattr(self._executor, "register_dev_folders", None) + if register is not None: + register(folders) + else: + # TODO this should be done through update parameters + warnings.warn( + "Ignoring dev folder registration as it is only supported (and needed) for internal cluster" + ) + + @classmethod + def _typed_parameters(cls) -> tp.Dict[str, Type]: + return { + "name": str, + "timeout_min": int, + "mem_gb": float, + "nodes": int, + "cpus_per_task": int, + "gpus_per_node": int, + "tasks_per_node": int, + "stderr_to_stdout": bool, + } + + @classmethod + def _valid_parameters(cls) -> tp.Set[str]: + return set(cls._typed_parameters().keys()) + + def _internal_update_parameters(self, **kwargs: Any) -> None: + """Updates submission parameters to srun/crun. + + Parameters + ---------- + AutoExecutors provides shared parameters that are translated for each specific cluster. + Those are: timeout_min (int), mem_gb (int), gpus_per_node (int), cpus_per_task (int), + nodes (int), tasks_per_node (int) and name (str). + Cluster specific parameters can be specified by prefixing them with the cluster name. + + Notes + ----- + - Cluster specific parameters win over shared parameters. + eg: if both `slurm_time` and `timeout_min` are provided, then: + - `slurm_time` is used on the slurm cluster + - `timeout_min` is used on other clusters + """ + # We handle None as not set. + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # check type of replaced variables + generics = AutoExecutor._typed_parameters() + for name, expected_type in generics.items(): + if expected_type == float: + expected_type = (int, float) # type: ignore + if name in kwargs: + assert isinstance(kwargs[name], expected_type), ( + f'Parameter "{name}" expected type {expected_type} ' + f'(but value: "{kwargs[name]}")' + ) + + _convert_deprecated_args(kwargs, self._deprecated_args) + specific = [x.split("_", 1) for x in kwargs if x not in generics] + + invalid = [] + executors = plugins.get_executors() + for ex_arg in specific: + if len(ex_arg) != 2: + invalid.append( + f"Parameter '{ex_arg[0]}' need to be prefixed by an executor name." + ) + continue + ex, arg = ex_arg + + if ex not in executors: + invalid.append(f"Unknown executor '{ex}' in parameter '{ex}_{arg}'.") + continue + + valid = executors[ex]._valid_parameters() + if arg not in valid and arg not in generics: + invalid.append( + f"Unknown argument '{arg}' for executor '{ex}' in parameter '{ex}_{arg}'." + + " Valid arguments: " + + ", ".join(valid) + ) + continue + if invalid: + invalid.extend( + [ + f"Known executors: {', '.join(executors.keys())}", + f"As a reminder, shared/generic (non-prefixed) parameters are: {generics}.", + ] + ) + raise NameError("\n".join(invalid)) + + # add cluster specific generic overrides + kwargs.update( + **{ + arg: kwargs.pop(f"{ex}_{arg}") + for ex, arg in specific + if ex == self.cluster and arg in generics + } + ) + parameters = self._executor._convert_parameters( + {k: kwargs[k] for k in kwargs if k in generics} + ) + # update parameters in the core executor + for ex, arg in specific: + # update cluster specific non-generic arguments + if arg not in generics and ex == self.cluster: + parameters[arg] = kwargs[f"{ex}_{arg}"] + + self._executor._internal_update_parameters(**parameters) + + def _internal_process_submissions( + self, delayed_submissions: tp.List[DelayedSubmission] + ) -> tp.List[Job[tp.Any]]: + return self._executor._internal_process_submissions(delayed_submissions) + + +def flexible_init( + cls: Type[Executor], folder: Union[str, Path], **kwargs: Any +) -> Executor: + prefix = cls.name() + "_" + return cls( + folder, + **{k[len(prefix) :]: val for k, val in kwargs.items() if k.startswith(prefix)}, + ) diff --git a/src/submitit/auto/test_auto.py b/src/submitit/auto/test_auto.py new file mode 100644 index 0000000..5e4859b --- /dev/null +++ b/src/submitit/auto/test_auto.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import sys +from pathlib import Path + +import pytest + +from ..local import debug +from ..slurm import test_slurm +from . import auto + + +def test_slurm_executor(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr(debug.DebugExecutor, "_valid_parameters", lambda: {"blabla"}) + with test_slurm.mocked_slurm(): + executor = auto.AutoExecutor(folder=tmp_path) + assert executor.cluster == "slurm" + + # local_xxx parameter is ignored + executor.update_parameters(mem_gb=2, name="machin", debug_blabla="blublu") + params = executor._executor.parameters + assert params == {"mem": "2GB", "job_name": "machin"} + + # shared parameter with wrong type + with pytest.raises(AssertionError): + executor.update_parameters(mem_gb="2.0GB") # should be int + # unknown shared parameter + with pytest.raises(NameError): + executor.update_parameters(blublu=2.0) + # unknown slurm parameter + with pytest.raises(NameError): + executor.update_parameters(slurm_host_filter="blublu") + # check that error message contains all + with pytest.raises(NameError, match=r"debug_blublu.*\n.*local_num_threads"): + executor.update_parameters(debug_blublu=2.0, local_num_threads=4) + + +def test_local_executor(tmp_path: Path) -> None: + with test_slurm.mocked_slurm(): + executor = auto.AutoExecutor(folder=tmp_path, cluster="local") + assert executor.cluster == "local" + executor.update_parameters(local_cpus_per_task=2) + + +def test_executor_argument(tmp_path: Path) -> None: + with test_slurm.mocked_slurm(): + executor = auto.AutoExecutor(folder=tmp_path, slurm_max_num_timeout=22) + assert getattr(executor._executor, "max_num_timeout", None) == 22 + + # Local executor + executor = auto.AutoExecutor( + folder=tmp_path, cluster="local", slurm_max_num_timeout=22 + ) + assert getattr(executor._executor, "max_num_timeout", None) != 22 + + +def test_executor_unknown_argument(tmp_path: Path) -> None: + with test_slurm.mocked_slurm(): + with pytest.raises(TypeError): + auto.AutoExecutor(folder=tmp_path, slurm_foobar=22) + + +def test_executor_deprecated_arguments(tmp_path: Path) -> None: + with test_slurm.mocked_slurm(): + with pytest.warns(UserWarning, match="slurm_max_num_timeout"): + auto.AutoExecutor(folder=tmp_path, max_num_timeout=22) + + +def test_deprecated_argument(tmp_path: Path, monkeypatch) -> None: + monkeypatch.setattr(debug.DebugExecutor, "_valid_parameters", lambda: {"blabla"}) + with test_slurm.mocked_slurm(): + executor = auto.AutoExecutor(folder=tmp_path) + assert executor.cluster == "slurm" + + # debug 'blabla' parameter is ignored + with pytest.warns(UserWarning, match=r"blabla.*debug_blabla"): + executor.update_parameters(mem_gb=2, blabla="blublu") + + +def test_overriden_arguments(tmp_path: Path) -> None: + with test_slurm.mocked_slurm(): + slurm_ex = auto.AutoExecutor(folder=tmp_path, cluster="slurm") + + slurm_ex.update_parameters( + timeout_min=60, slurm_timeout_min=120, tasks_per_node=2, slurm_ntasks_per_node=3 + ) + slurm_params = slurm_ex._executor.parameters + # slurm use time + assert slurm_params == {"time": 120, "ntasks_per_node": 3} + + # others use timeout_min + local_ex = auto.AutoExecutor(folder=tmp_path, cluster="local") + local_ex.update_parameters(timeout_min=60, slurm_time=120) + + +def test_auto_batch_watcher(tmp_path: Path) -> None: + with test_slurm.mocked_slurm(): + executor = auto.AutoExecutor(folder=tmp_path) + with executor.batch(): + job = executor.submit(print, "hi") + assert not job.done() + + +def test_redirect_stdout_stderr(executor) -> None: + def log_to_stderr_and_stdout(): + print("hello") + print("world", file=sys.stderr) + + executor.update_parameters(stderr_to_stdout=True) + job = executor.submit(log_to_stderr_and_stdout) + job.wait() + assert job.stderr() is None + stdout = job.stdout() + assert "hello" in stdout + assert "world" in stdout + + executor.update_parameters(stderr_to_stdout=False) + job = executor.submit(log_to_stderr_and_stdout) + job.wait() + assert "world" in job.stderr() + assert "hello" in job.stdout() diff --git a/src/submitit/conftest.py b/src/submitit/conftest.py new file mode 100644 index 0000000..6277268 --- /dev/null +++ b/src/submitit/conftest.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import time +from pathlib import Path + +import pytest + +from .local.local import LocalExecutor + + +@pytest.fixture() +def executor(tmp_path: Path) -> LocalExecutor: + return LocalExecutor(tmp_path) + + +@pytest.fixture(params=["a_0", "a 0", 'a"=0"', "a'; echo foo", r"a\=0", r"a\=", "a\n0"]) +def weird_tmp_path(request, tmp_path: Path) -> Path: + return tmp_path / request.param + + +@pytest.fixture() +def fast_forward_clock(monkeypatch): + """Allows to go in the future.""" + clock_time = [time.time()] + + monkeypatch.setattr(time, "time", lambda: clock_time[0]) + + def _fast_forward(minutes: float): + clock_time[0] += minutes * 60 + + return _fast_forward diff --git a/src/submitit/core/__init__.py b/src/submitit/core/__init__.py new file mode 100644 index 0000000..602d268 --- /dev/null +++ b/src/submitit/core/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# diff --git a/src/submitit/core/_submit.py b/src/submitit/core/_submit.py new file mode 100644 index 0000000..3b3986c --- /dev/null +++ b/src/submitit/core/_submit.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +from submitit.core.submission import submitit_main + +if __name__ == "__main__": + # This script is called by Executor.submit + submitit_main() diff --git a/src/submitit/core/core.py b/src/submitit/core/core.py new file mode 100644 index 0000000..d74b005 --- /dev/null +++ b/src/submitit/core/core.py @@ -0,0 +1,973 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import abc +import asyncio +import contextlib +import subprocess +import time as _time +import typing as tp +import uuid +import warnings +from pathlib import Path + +from typing_extensions import TypedDict + +from . import logger, utils + +# R as in "Result", so yes it's covariant. +# pylint: disable=typevar-name-incorrect-variance +R = tp.TypeVar("R", covariant=True) + + +class InfoWatcher: + """An instance of this class is shared by all jobs, and is in charge of calling slurm to check status for + all jobs at once (so as not to overload it). It is also in charge of dealing with errors. + Cluster is called at 0s, 2s, 4s, 8s etc... in the begginning of jobs, then at least every delay_s (default: 60) + + Parameters + ---------- + delay_s: int + Maximum delay before each non-forced call to the cluster. + """ + + # pylint: disable=too-many-instance-attributes + + def __init__(self, delay_s: int = 60) -> None: + self._delay_s = delay_s + self._registered: tp.Set[str] = set() + self._finished: tp.Set[str] = set() + self._info_dict: tp.Dict[str, tp.Dict[str, str]] = {} + self._output = b"" # for the record + self._start_time = 0.0 + self._last_status_check = float("-inf") + self._num_calls = 0 + + def read_info(self, string: tp.Union[bytes, str]) -> tp.Dict[str, tp.Dict[str, str]]: + raise NotImplementedError + + def _make_command(self) -> tp.Optional[tp.List[str]]: + raise NotImplementedError + + def get_state(self, job_id: str, mode: str = "standard") -> str: + raise NotImplementedError + + @property + def num_calls(self) -> int: + """Number of calls to sacct""" + return self._num_calls + + def clear(self) -> None: + """Clears cache. + This should hopefully not be used. If you have to use it, please add a github issue. + """ + self._finished = set() + self._start_time = _time.time() + self._last_status_check = float("-inf") + self._info_dict = {} + self._output = b"" + + def get_info(self, job_id: str, mode: str = "standard") -> tp.Dict[str, str]: + """Returns a dict containing info about the job. + State of finished jobs are cached (use watcher.clear() to remove all cache) + + Parameters + ---------- + job_id: str + id of the job on the cluster + mode: str + one of "force" (forces a call), "standard" (calls regularly) or "cache" (does not call) + """ + if job_id is None: + raise RuntimeError("Cannot call sacct without a slurm id") + if job_id not in self._registered: + self.register_job(job_id) + # check with a call to sacct/cinfo + self.update_if_long_enough(mode) + return self._info_dict.get(job_id, {}) + + def is_done(self, job_id: str, mode: str = "standard") -> bool: + """Returns whether the job is finished. + + Parameters + ---------- + job_id: str + id of the job on the cluster + mode: str + one of "force" (forces a call), "standard" (calls regularly) or "cache" (does not call) + """ + state = self.get_state(job_id, mode=mode) + return state.upper() not in ["READY", "PENDING", "RUNNING", "UNKNOWN", "REQUEUED", "COMPLETING"] + + def update_if_long_enough(self, mode: str) -> None: + """Updates if forced to, or if the delay is reached + (Force-updates with less than 1ms delay are ignored) + Also checks for finished jobs + """ + assert mode in ["standard", "force", "cache"] + if mode == "cache": + return + last_check_delta = _time.time() - self._last_status_check + last_job_delta = _time.time() - self._start_time + refresh_delay = min(self._delay_s, max(2, last_job_delta / 2)) + if mode == "force": + refresh_delay = 0.001 + + # the following will call update at time 0s, 2s, 4, 8, 16, 32, 64, 124 (delta 60), 184 (delta 60) etc... of last added job + # (for delay_s = 60) + if last_check_delta > refresh_delay: + self.update() + + def update(self) -> None: + """Updates the info of all registered jobs with a call to sacct""" + command = self._make_command() + if command is None: + return + self._num_calls += 1 + try: + logger.get_logger().debug(f"Call #{self.num_calls} - Command {' '.join(command)}") + self._output = subprocess.check_output(command, shell=False) + except Exception as e: + logger.get_logger().warning( + f"Call #{self.num_calls} - Bypassing sacct error {e}, status may be inaccurate." + ) + else: + self._info_dict.update(self.read_info(self._output)) + self._last_status_check = _time.time() + # check for finished jobs + to_check = self._registered - self._finished + for job_id in to_check: + if self.is_done(job_id, mode="cache"): + self._finished.add(job_id) + + def register_job(self, job_id: str) -> None: + """Register a job on the instance for shared update""" + assert isinstance(job_id, str) + self._registered.add(job_id) + self._start_time = _time.time() + self._last_status_check = float("-inf") + + +# pylint: disable=too-many-public-methods +class Job(tp.Generic[R]): + """Access to a cluster job information and result. + + Parameters + ---------- + folder: Path/str + A path to the submitted job file + job_id: str + the id of the cluster job + tasks: List[int] + The ids of the tasks associated to this job. + If None, the job has only one task (with id = 0) + """ + + _cancel_command = "dummy" + _results_timeout_s = 15 + watcher = InfoWatcher() + + def __init__(self, folder: tp.Union[Path, str], job_id: str, tasks: tp.Sequence[int] = (0,)) -> None: + self._job_id = job_id + self._tasks = tuple(tasks) + self._sub_jobs: tp.Sequence["Job[R]"] = [] + self._cancel_at_deletion = False + if len(tasks) > 1: + # This is a meta-Job + self._sub_jobs = [self.__class__(folder=folder, job_id=job_id, tasks=(k,)) for k in self._tasks] + self._paths = utils.JobPaths(folder, job_id=job_id, task_id=self.task_id) + self._start_time = _time.time() + self._last_status_check = self._start_time # for the "done()" method + # register for state updates with watcher + self._register_in_watcher() + + def _register_in_watcher(self) -> None: + if not self._tasks[0]: # only register for task=0 + self.watcher.register_job(self.job_id) + + @property + def job_id(self) -> str: + return self._job_id + + @property + def paths(self) -> utils.JobPaths: + return self._paths + + @property + def num_tasks(self) -> int: + """Returns the number of tasks in the Job""" + if not self._sub_jobs: + return 1 + return len(self._sub_jobs) + + def submission(self) -> utils.DelayedSubmission: + """Returns the submitted object, with attributes `function`, `args` and `kwargs`""" + assert ( + self.paths.submitted_pickle.exists() + ), f"Cannot find job submission pickle: {self.paths.submitted_pickle}" + return utils.DelayedSubmission.load(self.paths.submitted_pickle) + + def cancel_at_deletion(self, value: bool = True) -> "Job[R]": + """Sets whether the job deletion in the python environment triggers + cancellation of the corresponding job in the cluster + By default, jobs are not cancelled unless this method is called to turn the + option on. + + Parameters + ---------- + value: bool + if True, the cluster job will be cancelled at the instance deletion, if False, it + will not. + + Returns + ------- + Job + the current job (for chaining at submission for instance: "job = executor.submit(...).cancel_at_deletion()") + """ + self._cancel_at_deletion = value + return self + + def task(self, task_id: int) -> "Job[R]": + """Returns a given sub-Job (task). + + Parameters + ---------- + task_id + The id of the task. Must be between 0 and self.num_tasks + Returns + ------- + job + The sub_job. You can call all Job methods on it (done, stdout, ...) + If the job doesn't have sub jobs, return the job itself. + """ + if not 0 <= task_id < self.num_tasks: + raise ValueError(f"task_id {task_id} must be between 0 and {self.num_tasks - 1}") + + if not self._sub_jobs: + return self + return self._sub_jobs[task_id] + + def cancel(self, check: bool = True) -> None: + """Cancels the job + + Parameters + ---------- + check: bool + whether to wait for completion and check that the command worked + """ + (subprocess.check_call if check else subprocess.call)( + [self._cancel_command, f"{self.job_id}"], shell=False + ) + + def result(self) -> R: + r = self.results() + assert not self._sub_jobs, "You should use `results()` if your job has subtasks." + return r[0] + + def results(self) -> tp.List[R]: + """Waits for and outputs the result of the submitted function + + Returns + ------- + output + the output of the submitted function. + If the job has several tasks, it will return the output of every tasks in a List + + Raises + ------ + Exception + Any exception raised by the job + """ + self.wait() + + if self._sub_jobs: + return [tp.cast(R, sub_job.result()) for sub_job in self._sub_jobs] + + outcome, result = self._get_outcome_and_result() + if outcome == "error": + job_exception = self.exception() + if job_exception is None: + raise RuntimeError("Unknown job exception") + raise job_exception # pylint: disable=raising-bad-type + return [result] + + def exception(self) -> tp.Optional[tp.Union[utils.UncompletedJobError, utils.FailedJobError]]: + """Waits for completion and returns (not raise) the + exception containing the error log of the job + + Returns + ------- + Exception/None + the exception if any was raised during the job. + If the job has several tasks, it returns the exception of the task with + smallest id that failed. + + Raises + ------ + UncompletedJobError + In case the job never completed + """ + self.wait() + + if self._sub_jobs: + all_exceptions = [sub_job.exception() for sub_job in self._sub_jobs] + # unexpected pylint issue on correct code: + exceptions = [ + e for e in all_exceptions if e is not None # pylint: disable=used-before-assignment + ] + if not exceptions: + return None + return exceptions[0] + + try: + outcome, trace = self._get_outcome_and_result() + except utils.UncompletedJobError as e: + return e + if outcome == "error": + return utils.FailedJobError( + f"Job (task={self.task_id}) failed during processing with trace:\n" + f"----------------------\n{trace}\n" + "----------------------\n" + f"You can check full logs with 'job.stderr({self.task_id})' and 'job.stdout({self.task_id})'" + f"or at paths:\n - {self.paths.stderr}\n - {self.paths.stdout}" + ) + return None + + def _get_outcome_and_result(self) -> tp.Tuple[str, tp.Any]: + """Getter for the output of the submitted function. + + Returns + ------- + outcome + the outcome of the job: either "error" or "success" + result + the output of the submitted function + + Raises + ------ + UncompletedJobError + if the job is not finished or failed outside of the job (from slurm) + """ + assert not self._sub_jobs, "This should not be called for a meta-job" + + p = self.paths.folder + timeout = self._results_timeout_s + try: + # trigger cache update: https://stackoverflow.com/questions/3112546/os-path-exists-lies/3112717 + p.chmod(p.stat().st_mode) + except PermissionError: + # chmod requires file ownership and might fail. + # Increase the timeout since we can't force cache refresh. + timeout *= 2 + # if filesystem is slow, we need to wait a bit for result_pickle. + start_wait = _time.time() + while not self.paths.result_pickle.exists() and _time.time() - start_wait < timeout: + _time.sleep(1) + if not self.paths.result_pickle.exists(): + message = [ + f"Job {self.job_id} (task: {self.task_id}) with path {self.paths.result_pickle}", + f"has not produced any output (state: {self.state})", + ] + log = self.stderr() + if log: + message.extend(["Error stream produced:", "-" * 40, log]) + elif self.paths.stdout.exists(): + log = subprocess.check_output(["tail", "-40", str(self.paths.stdout)], encoding="utf-8") + message.extend( + [f"No error stream produced. Look at stdout: {self.paths.stdout}", "-" * 40, log] + ) + else: + message.append(f"No output/error stream produced ! Check: {self.paths.stdout}") + raise utils.UncompletedJobError("\n".join(message)) + try: + output: tp.Tuple[str, tp.Any] = utils.pickle_load(self.paths.result_pickle) + except EOFError: + warnings.warn(f"EOFError on file {self.paths.result_pickle}, trying again in 2s") # will it work? + _time.sleep(2) + output = utils.pickle_load(self.paths.result_pickle) + return output + + def wait(self) -> None: + """Wait while no result find is found and the state is + either PENDING or RUNNING. + The state is checked from slurm at least every min and the result path + every second. + """ + while not self.done(): + _time.sleep(1) + + def done(self, force_check: bool = False) -> bool: + """Checks whether the job is finished. + This is done by checking if the result file is present, + or checking the job state regularly (at least every minute) + If the job has several tasks, the job is done once all tasks are done. + + Parameters + ---------- + force_check: bool + Forces the slurm state update + + Returns + ------- + bool + whether the job is finished or not + + Note + ---- + This function is not foolproof, and may say that the job is not terminated even + if it is when the job failed (no result file, but job not running) because + we avoid calling sacct/cinfo everytime done is called + """ + # TODO: keep state info once job is finished? + if self._sub_jobs: + return all(sub_job.done() for sub_job in self._sub_jobs) + p = self.paths.folder + try: + # trigger cache update: https://stackoverflow.com/questions/3112546/os-path-exists-lies/3112717 + p.chmod(p.stat().st_mode) + except OSError: + pass + if self.paths.result_pickle.exists(): + return True + # check with a call to sacct/cinfo + if self.watcher.is_done(self.job_id, mode="force" if force_check else "standard"): + return True + return False + + @property + def task_id(self) -> tp.Optional[int]: + return None if len(self._tasks) > 1 else self._tasks[0] + + @property + def state(self) -> str: + """State of the job (does not force an update)""" + return self.watcher.get_state(self.job_id, mode="standard") + + def get_info(self, mode: str = "force") -> tp.Dict[str, str]: + """Returns informations about the job as a dict (sacct call)""" + return self.watcher.get_info(self.job_id, mode=mode) + + def _get_logs_string(self, name: str) -> tp.Optional[str]: + """Returns a string with the content of the log file + or None if the file does not exist yet + + Parameter + --------- + name: str + either "stdout" or "stderr" + """ + paths = {"stdout": self.paths.stdout, "stderr": self.paths.stderr} + if name not in paths: + raise ValueError(f'Unknown "{name}", available are {list(paths.keys())}') + if not paths[name].exists(): + return None + with paths[name].open("r") as f: + string: str = f.read() + return string + + def stdout(self) -> tp.Optional[str]: + """Returns a string with the content of the print log file + or None if the file does not exist yet + """ + if self._sub_jobs: + stdout_ = [sub_job.stdout() for sub_job in self._sub_jobs] + stdout_not_none = [s for s in stdout_ if s is not None] + if not stdout_not_none: + return None + return "\n".join(stdout_not_none) + + return self._get_logs_string("stdout") + + def stderr(self) -> tp.Optional[str]: + """Returns a string with the content of the error log file + or None if the file does not exist yet + """ + if self._sub_jobs: + stderr_ = [sub_job.stderr() for sub_job in self._sub_jobs] + stderr_not_none: tp.List[str] = [s for s in stderr_ if s is not None] + if not stderr_not_none: + return None + return "\n".join(stderr_not_none) + return self._get_logs_string("stderr") + + def awaitable(self) -> "AsyncJobProxy[R]": + """Returns a proxy object that provides asyncio methods + for this Job. + """ + return AsyncJobProxy(self) + + def __repr__(self) -> str: + state = "UNKNOWN" + try: + state = self.state + except Exception as e: + logger.get_logger().warning(f"Bypassing state error:\n{e}") + return f'{self.__class__.__name__}' + + def __del__(self) -> None: + if self._cancel_at_deletion: + if not self.watcher.is_done(self.job_id, mode="cache"): + self.cancel(check=False) + + def __getstate__(self) -> tp.Dict[str, tp.Any]: + return self.__dict__ # for pickling (see __setstate__) + + def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None: + """Make sure jobs are registered when loaded from a pickle""" + self.__dict__.update(state) + self._register_in_watcher() + + +class DelayedJob(Job[R]): + """ + Represents a Job that have been queue for submission by an executor, + but hasn't yet been scheduled. + Typically obtained by calling `ex.submit` within a `ex.batch()` context + + Trying to read the attributes of the job will, by default, fail. + But if you passed `ex.batch(allow_implicit_submission=True)` then + the attribute read will in fact force the job submission, + and you'll obtain a real job instead. + """ + + def __init__(self, ex: "Executor"): + # pylint: disable = super-init-not-called + self._submitit_executor = ex + + def __getattr__(self, name: str) -> tp.Any: + # _cancel_at_deletion is used in __del__, we don't want it to trigger submission + if name == "_cancel_at_deletion": + return False + + ex = self.__dict__["_submitit_executor"] + # this submits the batch so as to fill the instance attributes + # this may return false if we try to submit within executor.batch() + # without passing `executor.batch(allow_implicit_submission=True)` + if not ex._allow_implicit_submissions: + raise AttributeError( + "Accesssing job attributes is forbidden within 'with executor.batch()' context" + ) + ex._submit_delayed_batch() + # Ensure that _promote did get called, otherwise getattr will trigger a stack overflow + assert self.__class__ != DelayedJob, f"Executor {ex} didn't properly submitted {self} !" + return getattr(self, name) + + def _promote(self, new_job: Job[tp.Any]) -> None: + # fill in the empty shell, the pickle way + self.__dict__.pop("_submitit_executor", None) + self.__dict__.update(new_job.__dict__) + # pylint: disable=attribute-defined-outside-init + self.__class__ = new_job.__class__ # type: ignore + + def __repr__(self) -> str: + return object.__repr__(self) + + +class AsyncJobProxy(tp.Generic[R]): + def __init__(self, job: Job[R]): + self.job = job + + async def wait(self, poll_interval: tp.Union[int, float] = 1) -> None: + """same as wait() but with asyncio sleep.""" + while not self.job.done(): + await asyncio.sleep(poll_interval) + + async def result(self, poll_interval: tp.Union[int, float] = 1) -> R: + """asyncio version of the result() method. + Wait asynchornously for the result to be available by polling the self.done() method. + Parameters + ---------- + poll_interval: int or float + how often to check if the result is available, in seconds + """ + await self.wait(poll_interval) + return self.job.result() + + async def results(self, poll_interval: tp.Union[int, float] = 1) -> tp.List[R]: + """asyncio version of the results() method. + + Waits asynchornously for ALL the results to be available by polling the self.done() method. + + Parameters + ---------- + poll_interval: int or float + how often to check if the result is available, in seconds + """ + await self.wait(poll_interval) + # results are ready now + return self.job.results() + + def results_as_completed(self, poll_interval: tp.Union[int, float] = 1) -> tp.Iterator[asyncio.Future]: + """awaits for all tasks results concurrently. Note that the order of results is not guaranteed to match the order + of the tasks anymore as the earliest task coming back might not be the first one you sent. + + Returns + ------- + an iterable of Awaitables that can be awaited on to get the earliest result available of the remaining tasks. + + Parameters + ---------- + poll_interval: int or float + how often to check if the result is available, in seconds + + (see https://docs.python.org/3/library/asyncio-task.html#asyncio.as_completed) + """ + if self.job.num_tasks > 1: + yield from asyncio.as_completed( + [self.job.task(i).awaitable().result(poll_interval) for i in range(self.job.num_tasks)] + ) + + # there is only one result anyway, let's just use async result + yield asyncio.ensure_future(self.result()) + + +_MSG = ( + "Interactions with jobs are not allowed within " + '"with executor.batch()" context (submissions/creations only happens at exit time).' +) + + +class EquivalenceDict(TypedDict): + """Gives the specific name of the params shared across all plugins.""" + + # Note that all values are typed as string, even though they correspond to integer. + # This allow to have a static typing on the "_equivalence_dict" method implemented + # by plugins. + # We could chose to put the proper types, but that wouldn't be enough to typecheck + # the calls to `update_parameters` which uses kwargs. + name: str + timeout_min: str + mem_gb: str + nodes: str + cpus_per_task: str + gpus_per_node: str + tasks_per_node: str + + +class Executor(abc.ABC): + """Base job executor. + + Parameters + ---------- + folder: Path/str + folder for storing job submission/output and logs. + """ + + job_class: tp.Type[Job[tp.Any]] = Job + + def __init__(self, folder: tp.Union[str, Path], parameters: tp.Optional[tp.Dict[str, tp.Any]] = None): + self.folder = Path(folder).expanduser().absolute() + self.parameters = {} if parameters is None else parameters + # storage for the batch context manager, for batch submissions: + self._delayed_batch: tp.Optional[tp.List[tp.Tuple[Job[tp.Any], utils.DelayedSubmission]]] = None + self._allow_implicit_submissions = False + + @classmethod + def name(cls) -> str: + n = cls.__name__ + if n.endswith("Executor"): + n = n[: -len("Executor")] + return n.lower() + + @contextlib.contextmanager + def batch(self, allow_implicit_submissions: bool = False) -> tp.Iterator[None]: + """Creates a context within which all submissions are packed into a job array. + By default the array submissions happens when leaving the context + + Parameter + --------- + allow_implicit_submissions: bool + submits the current batch whenever a job attribute is accessed instead of raising an exception + + Example + ------- + jobs = [] + with executor.batch(): + for k in range(12): + jobs.append(executor.submit(add, k, 1)) + + Raises + ------ + AttributeError + if trying to access a job instance attribute while the batch is not exited, and + intermediate submissions are not allowed. + """ + self._allow_implicit_submissions = allow_implicit_submissions + if self._delayed_batch is not None: + raise RuntimeError('Nesting "with executor.batch()" contexts is not allowed.') + self._delayed_batch = [] + try: + yield None + except Exception as e: + logger.get_logger().error( + 'Caught error within "with executor.batch()" context, submissions are dropped.\n ' + ) + raise e + else: + self._submit_delayed_batch() + finally: + self._delayed_batch = None + + def _submit_delayed_batch(self) -> None: + assert self._delayed_batch is not None + if not self._delayed_batch: + if not self._allow_implicit_submissions: + warnings.warn( + 'No submission happened during "with executor.batch()" context.', category=RuntimeWarning + ) + return + jobs, submissions = zip(*self._delayed_batch) + new_jobs = self._internal_process_submissions(submissions) + for j, new_j in zip(jobs, new_jobs): + j._promote(new_j) + self._delayed_batch = [] + + def submit(self, fn: tp.Callable[..., R], *args: tp.Any, **kwargs: tp.Any) -> Job[R]: + ds = utils.DelayedSubmission(fn, *args, **kwargs) + if self._delayed_batch is not None: + job: Job[R] = DelayedJob(self) + self._delayed_batch.append((job, ds)) + else: + job = self._internal_process_submissions([ds])[0] + if type(job) is Job: # pylint: disable=unidiomatic-typecheck + raise RuntimeError("Executors should never return a base Job class (implementation issue)") + return job + + @abc.abstractmethod + def _internal_process_submissions( + self, delayed_submissions: tp.List[utils.DelayedSubmission] + ) -> tp.List[Job[tp.Any]]: + ... + + def map_array(self, fn: tp.Callable[..., R], *iterable: tp.Iterable[tp.Any]) -> tp.List[Job[R]]: + """A distributed equivalent of the map() built-in function + + Parameters + ---------- + fn: callable + function to compute + *iterable: Iterable + lists of arguments that are passed as arguments to fn. + + Returns + ------- + List[Job] + A list of Job instances. + + Example + ------- + a = [1, 2, 3] + b = [10, 20, 30] + executor.map_array(add, a, b) + # jobs will compute 1 + 10, 2 + 20, 3 + 30 + """ + submissions = [utils.DelayedSubmission(fn, *args) for args in zip(*iterable)] + if len(submissions) == 0: + warnings.warn("Received an empty job array") + return [] + return self._internal_process_submissions(submissions) + + def submit_array(self, fns: tp.Sequence[tp.Callable[[], R]]) -> tp.List[Job[R]]: + """Submit a list of job. This is useful when submiting different Checkpointable functions. + Be mindful that all those functions will be run with the same requirements + (cpus, gpus, timeout, ...). So try to make group of similar function calls. + + Parameters + ---------- + fns: list of callable + functions to compute. Those functions must not need any argument. + Tyically those are "Checkpointable" instance whose arguments + have been specified in the constructor, or partial functions. + + Returns + ------- + List[Job] + A list of Job instances. + + Example + ------- + a_vals = [1, 2, 3] + b_vals = [10, 20, 30] + fns = [functools.partial(int.__add__, a, b) for (a, b) in zip (a_vals, b_vals)] + executor.submit_array(fns) + # jobs will compute 1 + 10, 2 + 20, 3 + 30 + """ + submissions = [utils.DelayedSubmission(fn) for fn in fns] + if len(submissions) == 0: + warnings.warn("Received an empty job array") + return [] + return self._internal_process_submissions(submissions) + + def update_parameters(self, **kwargs: tp.Any) -> None: + """Update submision parameters.""" + if self._delayed_batch is not None: + raise RuntimeError( + 'Changing parameters within batch context "with executor.batch():" is not allowed' + ) + self._internal_update_parameters(**kwargs) + + @classmethod + def _equivalence_dict(cls) -> tp.Optional[EquivalenceDict]: + return None + + @classmethod + def _valid_parameters(cls) -> tp.Set[str]: + """Parameters that can be set through update_parameters""" + return set() + + def _convert_parameters(self, params: tp.Dict[str, tp.Any]) -> tp.Dict[str, tp.Any]: + """Convert generic parameters to their specific equivalent. + This has to be called **before** calling `update_parameters`. + + The default implementation only renames the key using `_equivalence_dict`. + """ + eq_dict = tp.cast(tp.Optional[tp.Dict[str, str]], self._equivalence_dict()) + if eq_dict is None: + return params + return {eq_dict.get(k, k): v for k, v in params.items()} + + def _internal_update_parameters(self, **kwargs: tp.Any) -> None: + """Update submission parameters.""" + self.parameters.update(kwargs) + + @classmethod + def affinity(cls) -> int: + """The 'score' of this executor on the current environment. + + -> -1 means unavailable + -> 0 means available but won't be started unless asked (eg debug executor) + -> 1 means available + -> 2 means available and is a highly scalable executor (cluster) + """ + return 1 + + +class PicklingExecutor(Executor): + """Base job executor. + + Parameters + ---------- + folder: Path/str + folder for storing job submission/output and logs. + """ + + def __init__(self, folder: tp.Union[Path, str], max_num_timeout: int = 3) -> None: + super().__init__(folder) + self.max_num_timeout = max_num_timeout + self._throttling = 0.2 + self._last_job_submitted = 0.0 + + def _internal_process_submissions( + self, delayed_submissions: tp.List[utils.DelayedSubmission] + ) -> tp.List[Job[tp.Any]]: + """Submits a task to the cluster. + + Parameters + ---------- + fn: callable + The function to compute + *args: any positional argument for the function + **kwargs: any named argument for the function + + Returns + ------- + Job + A Job instance, providing access to the job information, + including the output of the function once it is computed. + """ + eq_dict = self._equivalence_dict() + timeout_min = self.parameters.get(eq_dict["timeout_min"] if eq_dict else "timeout_min", 5) + jobs = [] + for delayed in delayed_submissions: + tmp_uuid = uuid.uuid4().hex + pickle_path = utils.JobPaths.get_first_id_independent_folder(self.folder) / f"{tmp_uuid}.pkl" + pickle_path.parent.mkdir(parents=True, exist_ok=True) + delayed.set_timeout(timeout_min, self.max_num_timeout) + delayed.dump(pickle_path) + + self._throttle() + self._last_job_submitted = _time.time() + job = self._submit_command(self._submitit_command_str) + job.paths.move_temporary_file(pickle_path, "submitted_pickle") + jobs.append(job) + return jobs + + def _throttle(self) -> None: + while _time.time() - self._last_job_submitted < self._throttling: + _time.sleep(self._throttling) + + @property + def _submitit_command_str(self) -> str: + # this is the command submitted from "submit" to "_submit_command" + return "dummy" + + def _submit_command(self, command: str) -> Job[tp.Any]: + """Submits a command to the cluster + It is recommended not to use this function since the Job instance assumes pickle + files will be created at the end of the job, and hence it will not work correctly. + You may use a CommandFunction as argument to the submit function instead. The only + problem with this latter solution is that stdout is buffered, and you will therefore + not be able to monitor the logs in real time. + + Parameters + ---------- + command: str + a command string + + Returns + ------- + Job + A Job instance, providing access to the crun job information. + Since it has no output, some methods will not be efficient + """ + tmp_uuid = uuid.uuid4().hex + submission_file_path = ( + utils.JobPaths.get_first_id_independent_folder(self.folder) / f"submission_file_{tmp_uuid}.sh" + ) + with submission_file_path.open("w") as f: + f.write(self._make_submission_file_text(command, tmp_uuid)) + command_list = self._make_submission_command(submission_file_path) + # run + output = utils.CommandFunction(command_list, verbose=False)() # explicit errors + job_id = self._get_job_id_from_submission_command(output) + tasks_ids = list(range(self._num_tasks())) + job: Job[tp.Any] = self.job_class(folder=self.folder, job_id=job_id, tasks=tasks_ids) + job.paths.move_temporary_file(submission_file_path, "submission_file") + self._write_job_id(job.job_id, tmp_uuid) + self._set_job_permissions(job.paths.folder) + return job + + def _write_job_id(self, job_id: str, uid: str) -> None: + """Write the job id in a file named {job-independent folder}/parent_job_id_{uid}. + This can create files read by plugins to get the job_id of the parent job + """ + + @abc.abstractmethod + def _num_tasks(self) -> int: + """Returns the number of tasks associated to the job""" + raise NotImplementedError + + @abc.abstractmethod + def _make_submission_file_text(self, command: str, uid: str) -> str: + """Creates the text of a file which will be created and run + for the submission (for slurm, this is sbatch file). + """ + raise NotImplementedError + + @abc.abstractmethod + def _make_submission_command(self, submission_file_path: Path) -> tp.List[str]: + """Create the submission command.""" + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def _get_job_id_from_submission_command(string: tp.Union[bytes, str]) -> str: + """Recover the job id from the output of the submission command.""" + raise NotImplementedError + + @staticmethod + def _set_job_permissions(folder: Path) -> None: + pass diff --git a/src/submitit/core/job_environment.py b/src/submitit/core/job_environment.py new file mode 100644 index 0000000..f56f6a1 --- /dev/null +++ b/src/submitit/core/job_environment.py @@ -0,0 +1,269 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import signal +import socket +import sys +import time +import types +from pathlib import Path +from typing import Any, ClassVar, Dict, Optional, Sequence + +from . import logger, utils +from .utils import DelayedSubmission, JobPaths + +_PREEMPT_SIG_ENV = "SUBMITIT_PREEMPT_SIGNAL" + + +class JobEnvironment: + """Describe the environment inside which the job is running. + This includes job id, number of GPUs available, ... + + This class can only be instantiated from a running submitit job. + + @plugin-dev: default implementation look for information into environment variables. + Override _env to map environment variable to each property. + """ + + # preemption signal uses USR2 as default, but this behavior + # can be overiden (eg: export SUBMITIT_PREEMPT_SIGNAL=USR2) + # CAUTION: NCCL may catch USR1 so it should be avoided + USR_SIG = os.environ.get(_PREEMPT_SIG_ENV, "USR1") + _env: ClassVar[Dict[str, str]] = {} + + def __new__(cls, *args: Any) -> "JobEnvironment": + if cls is not JobEnvironment: + return super().__new__(cls, *args) # type: ignore + + from . import plugins # pylint: disable=cyclic-import,import-outside-toplevel + + return plugins.get_job_environment() + + def __init__(self) -> None: + self.cluster = self.name() + + @classmethod + def name(cls) -> str: + n = cls.__name__ + if n.endswith("JobEnvironment"): + n = n[: -len("JobEnvironment")] + return n.lower() + + @property + def paths(self) -> JobPaths: + """Provides the paths used by submitit, including + stdout, stderr, submitted_pickle and folder. + """ + folder = os.environ["SUBMITIT_FOLDER"] + return JobPaths(folder, job_id=self.job_id, task_id=self.global_rank) + + def activated(self) -> bool: + """Tests if we are running inside this environment. + + @plugin-dev: assumes that the SUBMITIT_EXECUTOR variable has been + set to the executor name + """ + return os.environ.get("SUBMITIT_EXECUTOR", "") == self.name() + + @property + def hostname(self) -> str: + return socket.gethostname() + + @property + def hostnames(self) -> Sequence[str]: + return [self.hostname] + + @property + def job_id(self) -> str: + if self.array_job_id: + return f"{self.array_job_id}_{self.array_task_id}" + else: + return self.raw_job_id + + @property + def raw_job_id(self) -> str: + return os.environ[self._env["job_id"]] + + @property + def array_job_id(self) -> Optional[str]: + n = "array_job_id" + return None if n not in self._env else os.environ.get(self._env[n], None) + + @property + def array_task_id(self) -> Optional[str]: + n = "array_task_id" + return None if n not in self._env else os.environ.get(self._env[n], None) + + @property + def num_tasks(self) -> int: + """Total number of tasks for the job""" + return int(os.environ.get(self._env["num_tasks"], 1)) + + @property + def num_nodes(self) -> int: + """Total number of nodes for the job""" + return int(os.environ.get(self._env["num_nodes"], 1)) + + @property + def node(self) -> int: + """Id of the current node""" + return int(os.environ.get(self._env["node"], 0)) + + @property + def global_rank(self) -> int: + """Global rank of the task""" + return int(os.environ.get(self._env["global_rank"], 0)) + + @property + def local_rank(self) -> int: + """Local rank of the task, ie on the current node.""" + return int(os.environ.get(self._env["local_rank"], 0)) + + def __repr__(self) -> str: + # should look like this: + # JobEnvironment(job_id=17015819, hostname=learnfair0218, local_rank=2(3), node=1(2), global_rank=5(6)) + info = [f"{n}={getattr(self, n)}" for n in ("job_id", "hostname")] + names = ("local_rank", "node", "global_rank") + totals = [self.num_tasks // self.num_nodes, self.num_nodes, self.num_tasks] + info += [f"{n}={getattr(self, n)}({t})" for n, t in zip(names, totals)] + info_str = ", ".join(info) + return f"JobEnvironment({info_str})" + + @classmethod + def _usr_sig(cls) -> Any: + name = "SIG" + cls.USR_SIG + out = getattr(signal, name, None) + if out is None: + raise RuntimeError( + f"Unknown signal {name}, you may need to unset or update env var {_PREEMPT_SIG_ENV} (Eg: USR2)" + ) + return out + + def _handle_signals(self, paths: JobPaths, submission: DelayedSubmission) -> None: + """Set up signals handler for the current executable. + + The default implementation checkpoint the given submission and requeues it. + @plugin-dev: Should be adapted to the signals used in this cluster. + """ + handler = SignalHandler(self, paths, submission) + signal.signal(self._usr_sig(), handler.checkpoint_and_try_requeue) + # A priori we don't need other signals anymore, + # but still log them to make it easier to debug. + signal.signal(signal.SIGTERM, handler.bypass) + signal.signal(signal.SIGCONT, handler.bypass) + + # pylint: disable=unused-argument + def _requeue(self, countdown: int) -> None: + """Requeue the current job. + + @plugin-dev:Must be overridden by JobEnvironment implementations. + Use self.job_id to find what need to be requeued. + """ + + +class SignalHandler: + def __init__(self, env: JobEnvironment, job_paths: JobPaths, delayed: DelayedSubmission) -> None: + self.env = env + self._job_paths = job_paths + self._delayed = delayed + self._logger = logger.get_logger() + self._start_time = time.time() + + def has_timed_out(self) -> bool: + # SignalHandler is created by submitit as soon as the process start, + # so _start_time is an accurate measure of the global runtime of the job. + walltime = time.time() - self._start_time + max_walltime = self._delayed._timeout_min * 60 + guaranteed_walltime = min(max_walltime * 0.8, max_walltime - 10 * 60) + + timed_out = walltime >= guaranteed_walltime + if timed_out: + self._logger.info( + f"Job has timed out. Ran {walltime / 60:.0f} minutes out of requested {max_walltime / 60:.0f} minutes." + ) + else: + self._logger.info( + f"Job has not timed out. Ran {walltime / 60:.0f} minutes out of requested {max_walltime / 60:.0f} minutes." + ) + return timed_out + + def bypass(self, signum: int, frame: types.FrameType = None) -> None: # pylint:disable=unused-argument + self._logger.warning(f"Bypassing signal {signal.Signals(signum).name}") + + def checkpoint_and_try_requeue( + self, signum: int, frame: types.FrameType = None # pylint:disable=unused-argument + ) -> None: + timed_out = self.has_timed_out() + case = "timed-out" if timed_out else "preempted" + self._logger.warning( + f"Caught signal {signal.Signals(signum).name} on {socket.gethostname()}: this job is {case}." + ) + + procid = self.env.global_rank + if procid != 0: + self._logger.info(f"Not checkpointing nor requeuing since I am a slave (procid={procid}).") + # do not sys.exit, because it might kill the master task + return + + delayed = self._delayed + countdown = delayed._timeout_countdown - timed_out + no_requeue_reason = "" + if hasattr(delayed.function, "checkpoint"): + no_requeue_reason = _checkpoint(delayed, self._job_paths.submitted_pickle, countdown) + elif timed_out: + no_requeue_reason = "timed-out and not checkpointable" + if countdown < 0: # this is the end + no_requeue_reason = "timed-out too many times" + if no_requeue_reason: + # raise an error so as to create "result_pickle" file which notifies the job is over + # this is caught by the try/except in "process_job" + message = f"Job not requeued because: {no_requeue_reason}." + self._logger.info(message) + raise utils.UncompletedJobError(message) + # if everything went well, requeue! + self.env._requeue(countdown) + self._exit() + + def checkpoint_and_exit( + self, signum: int, frame: types.FrameType = None # pylint:disable=unused-argument + ) -> None: + # Note: no signal is actually bound to `checkpoint_and_exit` but this is used by plugins. + self._logger.info(f"Caught signal {signal.Signals(signum).name} on {socket.gethostname()}") + + procid = self.env.global_rank + if procid: + self._logger.info(f"Not checkpointing since I am a slave (procid={procid}).") + # do not sys.exit, because it might kill the master task + return + + delayed = self._delayed + if hasattr(delayed.function, "checkpoint"): + _checkpoint(self._delayed, self._job_paths.submitted_pickle, self._delayed._timeout_countdown) + self._exit() + + def _exit(self) -> None: + # extracted for mocking + self._logger.info("Exiting gracefully after preemption/timeout.") + sys.exit(-1) + + +def _checkpoint(delayed: DelayedSubmission, filepath: Path, countdown: int) -> str: + """Call the checkpoint method and dump the updated delayed. + + Returns: + -------- + no_requeue_reason: str + a string explaining while there was no requeuing, else empty string if requeuing works + """ + logger.get_logger().info("Calling checkpoint method.") + ckpt_delayed = delayed._checkpoint_function() + if ckpt_delayed is None: + return "checkpoint function returned None" + ckpt_delayed.set_timeout(delayed._timeout_min, countdown) + with utils.temporary_save_path(filepath) as tmp: + ckpt_delayed.dump(tmp) + return "" # requeues diff --git a/src/submitit/core/logger.py b/src/submitit/core/logger.py new file mode 100644 index 0000000..59e1d70 --- /dev/null +++ b/src/submitit/core/logger.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging.config +import os +from typing import Union + +# provide a way to change level through SUBMITIT_LOG_LEVEL environment variable: +# level "CRITICAL" (50) or more (eg.: "100") will deactivate submitit logger +# "NOCONFIG" will avoid configuration +LOG_VARNAME = "SUBMITIT_LOG_LEVEL" +level_str = os.environ.get(LOG_VARNAME, "INFO").upper() +level: Union[int, str] = level_str if not level_str.isdigit() else int(level_str) + + +CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": {"submitit_basic": {"format": "%(name)s %(levelname)s (%(asctime)s) - %(message)s"}}, + "handlers": { + "submitit_out": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "submitit_basic", + "stream": "ext://sys.stdout", + }, + "submitit_err": { + "class": "logging.StreamHandler", + "level": "WARNING", + "formatter": "submitit_basic", + "stream": "ext://sys.stderr", + }, + }, + "loggers": {"submitit": {"handlers": ["submitit_err", "submitit_out"], "level": level}}, +} + + +if level != "NOCONFIG": + logging.config.dictConfig(CONFIG) + + +def get_logger() -> logging.Logger: + return logging.getLogger("submitit") + + +def exception(*args: str) -> None: + get_logger().exception(*args) + + +def warning(*args: str) -> None: + get_logger().warning(*args) diff --git a/src/submitit/core/plugins.py b/src/submitit/core/plugins.py new file mode 100644 index 0000000..70c7463 --- /dev/null +++ b/src/submitit/core/plugins.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import functools +import os +from typing import TYPE_CHECKING, List, Mapping, Tuple, Type + +from ..core import logger + +if TYPE_CHECKING: + # Breaks the import cycle + from ..core.core import Executor + from ..core.job_environment import JobEnvironment + + +@functools.lru_cache() +def _get_plugins() -> Tuple[List[Type["Executor"]], List["JobEnvironment"]]: + # pylint: disable=cyclic-import,import-outside-toplevel + # Load dynamically to avoid import cycle + # pkg_resources goes through all modules on import. + import pkg_resources + + from ..local import debug, local + from ..slurm import slurm + + # TODO: use sys.modules.keys() and importlib.resources to find the files + # We load both kind of entry points at the same time because we have to go through all module files anyway. + executors: List[Type["Executor"]] = [slurm.SlurmExecutor, local.LocalExecutor, debug.DebugExecutor] + job_envs = [slurm.SlurmJobEnvironment(), local.LocalJobEnvironment(), debug.DebugJobEnvironment()] + for entry_point in pkg_resources.iter_entry_points("submitit"): + if entry_point.name not in ("executor", "job_environment"): + logger.warning(f"Found unknown entry point in package {entry_point.module_name}: {entry_point}") + continue + + try: + # call `load` rather than `resolve`. + # `load` also checks the module and its dependencies are correctly installed. + cls = entry_point.load() + except Exception as e: + # This may happen if the plugin haven't been correctly installed + logger.exception(f"Failed to load submitit plugin '{entry_point.module_name}': {e}") + continue + + if entry_point.name == "executor": + executors.append(cls) + else: + try: + job_env = cls() + except Exception as e: + logger.exception( + f"Failed to init JobEnvironment '{cls.name}' ({cls}) from submitit plugin '{entry_point.module_name}': {e}" + ) + continue + job_envs.append(job_env) + + return (executors, job_envs) + + +@functools.lru_cache() +def get_executors() -> Mapping[str, Type["Executor"]]: + # TODO: check collisions between executor names + return {ex.name(): ex for ex in _get_plugins()[0]} + + +def get_job_environment() -> "JobEnvironment": + # Don't cache this function. It makes testing harder. + # The slow part is the plugin discovery anyway. + envs = get_job_environments() + # bypassing can be helful for testing + if "_TEST_CLUSTER_" in os.environ: + c = os.environ["_TEST_CLUSTER_"] + assert c in envs, f"Unknown $_TEST_CLUSTER_='{c}', available: {envs.keys()}." + return envs[c] + for env in envs.values(): + # TODO? handle the case where several envs are valid + if env.activated(): + return env + raise RuntimeError( + f"Could not figure out which environment the job is runnning in. Known environments: {', '.join(envs.keys())}." + ) + + +@functools.lru_cache() +def get_job_environments() -> Mapping[str, "JobEnvironment"]: + return {env.name(): env for env in _get_plugins()[1]} diff --git a/src/submitit/core/submission.py b/src/submitit/core/submission.py new file mode 100644 index 0000000..8d85286 --- /dev/null +++ b/src/submitit/core/submission.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import argparse +import os +import time +import traceback +from pathlib import Path +from typing import Union + +try: # loading numpy before loading the pickle, to avoid unexpected interactions + # pylint: disable=unused-import + import numpy # type: ignore # noqa +except ImportError: + pass + +from . import job_environment, utils +from .logger import get_logger + + +def process_job(folder: Union[Path, str]) -> None: + """Loads a pickled job, runs it and pickles the output + + Parameter + --------- + folder: Path/str + path of the folder where the job pickle will be stored (with a name containing its uuid) + + Side-effect + ----------- + Creates a picked output file next to the job file. + """ + os.environ["SUBMITIT_FOLDER"] = str(folder) + env = job_environment.JobEnvironment() + paths = env.paths + logger = get_logger() + logger.info(f"Starting with {env}") + logger.info(f"Loading pickle: {paths.submitted_pickle}") + wait_time = 60 + for _ in range(wait_time): + if paths.submitted_pickle.exists(): + break + time.sleep(1) + if not paths.submitted_pickle.exists(): + raise RuntimeError( + f"Waited for {wait_time} seconds but could not find submitted jobs in path:\n{paths.submitted_pickle}" + ) + try: + delayed = utils.DelayedSubmission.load(paths.submitted_pickle) + env = job_environment.JobEnvironment() + handle_signals = bool(int(os.environ.get("SUBMITIT_HANDLE_SIGNALS", "0"))) + if handle_signals: + logger.info("Handling signals (SUBMITIT_HANDLE_SIGNALS=1)") + env._handle_signals(paths, delayed) + else: + logger.info("Ignoring signal handling (SUBMITIT_HANDLE_SIGNALS=0)") + result = delayed.result() + logger.info("Job completed successfully") + del delayed # if it blocks here, you have a race condition that must be solved! + with utils.temporary_save_path( + paths.result_pickle + ) as tmppath: # save somewhere else, and move + utils.cloudpickle_dump(("success", result), tmppath) + del result + logger.info("Exitting after successful completion") + except ( + Exception + ) as error: # TODO: check pickle methods for capturing traceback; pickling and raising + try: + with utils.temporary_save_path(paths.result_pickle) as tmppath: + utils.cloudpickle_dump(("error", traceback.format_exc()), tmppath) + except Exception as dumperror: + logger.error(f"Could not dump error:\n{error}\n\nbecause of {dumperror}") + logger.error("Submitted job triggered an exception") + raise error + + +def submitit_main() -> None: + parser = argparse.ArgumentParser(description="Run a job") + parser.add_argument( + "folder", type=str, help="Folder where the jobs are stored (in subfolder)" + ) + args = parser.parse_args() + process_job(args.folder) diff --git a/src/submitit/core/test_async.py b/src/submitit/core/test_async.py new file mode 100644 index 0000000..2af0a79 --- /dev/null +++ b/src/submitit/core/test_async.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +from pathlib import Path + +import pytest + +from . import submission, utils +from .test_core import FakeExecutor, _three_time + + +@pytest.mark.asyncio +async def test_result(tmp_path: Path, event_loop): + executor = FakeExecutor(folder=tmp_path) + job = executor.submit(_three_time, 8) + result_task = event_loop.create_task(job.awaitable().result()) + with utils.environment_variables(_TEST_CLUSTER_="slurm", SLURM_JOB_ID=str(job.job_id)): + submission.process_job(folder=job.paths.folder) + result = await result_task + assert result == 24 + + +@pytest.mark.asyncio +async def test_results_single(tmp_path: Path, event_loop): + executor = FakeExecutor(folder=tmp_path) + job = executor.submit(_three_time, 8) + result_task = event_loop.create_task(job.awaitable().results()) + with utils.environment_variables(_TEST_CLUSTER_="slurm", SLURM_JOB_ID=str(job.job_id)): + submission.process_job(folder=job.paths.folder) + result = await result_task + assert result == [24] + + +@pytest.mark.asyncio +async def test_results_ascompleted_single(tmp_path: Path): + executor = FakeExecutor(folder=tmp_path) + job = executor.submit(_three_time, 8) + with utils.environment_variables(_TEST_CLUSTER_="slurm", SLURM_JOB_ID=str(job.job_id)): + submission.process_job(folder=job.paths.folder) + count = 0 + for aws in job.awaitable().results_as_completed(): + result = await aws + count += 1 + assert result == 24 + assert count == 1 diff --git a/src/submitit/core/test_core.py b/src/submitit/core/test_core.py new file mode 100644 index 0000000..9942523 --- /dev/null +++ b/src/submitit/core/test_core.py @@ -0,0 +1,260 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +# pylint: disable=redefined-outer-name +import contextlib +import pickle +import subprocess +import sys +import time +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Sequence, Union +from unittest.mock import patch + +import pytest + +from . import core, submission, utils + + +# pylint: disable=no-self-use +class MockedSubprocess: + """Helper for mocking subprocess calls""" + + SACCT_HEADER = "JobID|State" + SACCT_JOB = "{j}|{state}\n{j}.ext+|{state}\n{j}.0|{state}" + + def __init__(self, known_cmds: Sequence[str] = None) -> None: + self.job_sacct: Dict[str, str] = {} + self.last_job: str = "" + self._subprocess_check_output = subprocess.check_output + self.known_cmds = known_cmds or [] + self.job_count = 12 + + def __call__(self, command: Sequence[str], **kwargs: Any) -> bytes: + program = command[0] + if program in ["sacct", "sbatch", "scancel"]: + return getattr(self, program)(command[1:]).encode() + elif program == "tail": + return self._subprocess_check_output(command, **kwargs) + else: + raise ValueError(f'Unknown command to mock "{command}".') + + def sacct(self, _: Sequence[str]) -> str: + return "\n".join(self.job_sacct.values()) + + def sbatch(self, args: Sequence[str]) -> str: + """Create a "RUNNING" job.""" + job_id = str(self.job_count) + self.job_count += 1 + sbatch_file = Path(args[0]) + array = 0 + if sbatch_file.exists(): + array_lines = [l for l in sbatch_file.read_text().splitlines() if "--array" in l] + if array_lines: + # SBATCH --array=0-4%3 + array = int(array_lines[0].split("=0-")[-1].split("%")[0]) + array += 1 + self.set_job_state(job_id, "RUNNING", array) + return f"Running job {job_id}\n" + + def scancel(self, _: Sequence[str]) -> str: + # TODO:should we call set_job_state ? + return "" + + def set_job_state(self, job_id: str, state: str, array: int = 0) -> None: + self.job_sacct[job_id] = self._sacct(state, job_id, array) + self.last_job = job_id + + def _sacct(self, state: str, job_id: str, array: int) -> str: + if array == 0: + lines = self.SACCT_JOB.format(j=job_id, state=state) + else: + lines = "\n".join(self.SACCT_JOB.format(j=f"{job_id}_{i}", state=state) for i in range(array)) + return "\n".join((self.SACCT_HEADER, lines)) + + def which(self, name: str) -> Optional[str]: + return "here" if name in self.known_cmds else None + + def mock_cmd_fn(self, *args, **_): + # CommandFunction(cmd)() ~= subprocess.check_output(cmd) + return lambda: self(*args) + + @contextlib.contextmanager + def context(self) -> Iterator[None]: + with patch("submitit.core.utils.CommandFunction", new=self.mock_cmd_fn): + with patch("subprocess.check_output", new=self): + with patch("shutil.which", new=self.which): + with patch("subprocess.check_call", new=self): + yield None + + @contextlib.contextmanager + def job_context(self, job_id: str) -> Iterator[None]: + with utils.environment_variables( + _USELESS_TEST_ENV_VAR_="1", SUBMITIT_EXECUTOR="slurm", SLURM_JOB_ID=str(job_id) + ): + yield None + + +class FakeInfoWatcher(core.InfoWatcher): + + # pylint: disable=abstract-method + def get_state(self, job_id: str, mode: str = "standard") -> str: + return "running" + + +class FakeJob(core.Job[core.R]): + + watcher = FakeInfoWatcher() + _cancel_at_deletion = False + + +class FakeExecutor(core.PicklingExecutor): + + job_class = FakeJob + + @property + def _submitit_command_str(self) -> str: + return "echo 1" + + def _num_tasks(self) -> int: + return 1 + + def _make_submission_file_text(self, command: str, uid: str) -> str: # pylint: disable=unused-argument + """Creates the text of a file which will be created and run + for the submission (for slurm, this is sbatch file). + """ + return command + "2" # this makes "echo 12" + + def _make_submission_command(self, submission_file_path: Path) -> List[str]: + """Create the submission command.""" + with submission_file_path.open("r") as f: + text: str = f.read() + return text.split() # this makes ["echo", "12"] + + @staticmethod + def _get_job_id_from_submission_command(string: Union[bytes, str]) -> str: + return string if isinstance(string, str) else string.decode() # this returns "12" + + +def _three_time(x: int) -> int: + return 3 * x + + +def do_nothing(*args: Any, **kwargs: Any) -> int: + print("my args", args, flush=True) + print("my kwargs", kwargs, flush=True) + if "sleep" in kwargs: + print("Waiting", flush=True) + time.sleep(int(kwargs["sleep"])) + if kwargs.get("error", False): + print("Raising", flush=True) + raise ValueError("Too bad") + print("Finishing", flush=True) + return 12 + + +def test_fake_job(tmp_path: Path) -> None: + job: FakeJob[int] = FakeJob(job_id="12", folder=tmp_path) + repr(job) + assert not job.done(force_check=True) + # logs + assert job.stdout() is None + assert job.stderr() is None + with job.paths.stderr.open("w") as f: + f.write("blublu") + assert job.stderr() == "blublu" + # result + utils.cloudpickle_dump(("success", 12), job.paths.result_pickle) + assert job.result() == 12 + # exception + assert job.exception() is None + utils.cloudpickle_dump(("error", "blublu"), job.paths.result_pickle) + assert isinstance(job.exception(), Exception) + with pytest.raises(core.utils.FailedJobError): + job.result() + + +def test_fake_job_cancel_at_deletion(tmp_path: Path) -> None: + job: FakeJob[Any] = FakeJob(job_id="12", folder=tmp_path).cancel_at_deletion() # type: ignore + with patch("subprocess.call", return_value=None) as mock: + assert mock.call_count == 0 + del job + assert mock.call_count == 1 + + +def test_fake_executor(tmp_path: Path) -> None: + executor = FakeExecutor(folder=tmp_path) + job = executor.submit(_three_time, 8) + assert job.job_id == "12" + assert job.paths.submission_file.exists() + with utils.environment_variables(_TEST_CLUSTER_="slurm", SLURM_JOB_ID=str(job.job_id)): + submission.process_job(folder=job.paths.folder) + assert job.result() == 24 + + +def test_fake_executor_batch(tmp_path: Path) -> None: + executor = FakeExecutor(folder=tmp_path) + with executor.batch(): + job = executor.submit(_three_time, 8) + assert isinstance(job, core.DelayedJob) + assert isinstance(job, FakeJob) + with executor.batch(): # make sure we can send a new batch + job = executor.submit(_three_time, 8) + assert isinstance(job, core.DelayedJob) + assert isinstance(job, FakeJob) + # bad update + with pytest.raises(RuntimeError): + with executor.batch(): + executor.update_parameters(blublu=12) + # bad access + with pytest.raises(AttributeError): + with executor.batch(): + job = executor.submit(_three_time, 8) + assert isinstance(job, core.DelayedJob) + job.job_id # pylint: disable=pointless-statement + assert isinstance(job, core.DelayedJob) + + with executor.batch(allow_implicit_submissions=True): + job = executor.submit(_three_time, 8) + assert isinstance(job, core.DelayedJob) + job.job_id # pylint: disable=pointless-statement + assert isinstance(job, FakeJob) + assert not executor._delayed_batch + + # empty context + with pytest.warns(RuntimeWarning): + with executor.batch(): + pass + # multi context + with pytest.raises(RuntimeError): + with executor.batch(): + with executor.batch(): + job = executor.submit(_three_time, 8) + assert isinstance(job, core.DelayedJob) + assert isinstance(job, FakeJob) + + +def test_unpickling_watcher_registration(tmp_path: Path) -> None: + executor = FakeExecutor(folder=tmp_path) + job = executor.submit(_three_time, 4) + original_job_id = job._job_id + job._job_id = "007" # pylint: disable=attribute-defined-outside-init + assert job.watcher._registered == {original_job_id} # still holds the old job id + pkl = pickle.dumps(job) + newjob = pickle.loads(pkl) + assert newjob.job_id == "007" + assert newjob.watcher._registered == {original_job_id, "007"} + + +if __name__ == "__main__": + args, kwargs = [], {} # oversimplisitic parser + for argv in sys.argv[1:]: + if "=" in argv: + key, val = argv.split("=") + kwargs[key.strip("-")] = val + else: + args.append(argv) + do_nothing(*args, **kwargs) diff --git a/src/submitit/core/test_plugins.py b/src/submitit/core/test_plugins.py new file mode 100644 index 0000000..2ac0009 --- /dev/null +++ b/src/submitit/core/test_plugins.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +import re +from pathlib import Path +from typing import Any, Iterator + +import pkg_resources +import pytest + +from . import core, plugins +from .job_environment import JobEnvironment + + +@pytest.mark.parametrize("env", plugins.get_job_environments().values()) +def test_env(env: JobEnvironment) -> None: + assert isinstance(env, JobEnvironment) + # We are not inside a submitit job + assert not env.activated() + assert type(env)._requeue is not JobEnvironment._requeue, "_requeue need to be overridden" + + +@pytest.mark.parametrize("ex", plugins.get_executors().values()) +def test_executors(ex: core.Executor) -> None: + assert isinstance(ex, type) + assert issubclass(ex, core.Executor) + assert ex.affinity() >= -1 + + +def test_finds_default_environments() -> None: + envs = plugins.get_job_environments() + assert len(envs) >= 3 + assert "slurm" in envs + assert "local" in envs + assert "debug" in envs + + +def test_finds_default_executors() -> None: + ex = plugins.get_executors() + assert len(ex) >= 3 + assert "slurm" in ex + assert "local" in ex + assert "debug" in ex + + +def test_job_environment_works(monkeypatch): + monkeypatch.setenv("_TEST_CLUSTER_", "slurm") + env = plugins.get_job_environment() + assert env.cluster == "slurm" + assert type(env).__name__ == "SlurmJobEnvironment" + + env2 = JobEnvironment() + assert env2.cluster == "slurm" + assert type(env2).__name__ == "SlurmJobEnvironment" + + +def test_job_environment_raises_outside_of_job() -> None: + with pytest.raises(RuntimeError, match=r"which environment.*slurm.*local.*debug"): + plugins.get_job_environment() + + +class PluginCreator: + def __init__(self, tmp_path: Path, monkeypatch): + self.tmp_path = tmp_path + self.monkeypatch = monkeypatch + + def add_plugin(self, name: str, entry_points: str, init: str): + plugin = self.tmp_path / name + plugin.mkdir(mode=0o777) + plugin_egg = plugin.with_suffix(".egg-info") + plugin_egg.mkdir(mode=0o777) + + (plugin_egg / "entry_points.txt").write_text(entry_points) + (plugin / "__init__.py").write_text(init) + + # also fix pkg_resources since it already has loaded old packages in other tests. + working_set = pkg_resources.WorkingSet([str(self.tmp_path)]) + self.monkeypatch.setattr(pkg_resources, "iter_entry_points", working_set.iter_entry_points) + + def __enter__(self) -> None: + _clear_plugin_cache() + self.monkeypatch.syspath_prepend(self.tmp_path) + + def __exit__(self, *exception: Any) -> None: + _clear_plugin_cache() + + +def _clear_plugin_cache() -> None: + plugins._get_plugins.cache_clear() + plugins.get_executors.cache_clear() + + +@pytest.fixture(name="plugin_creator") +def _plugin_creator(tmp_path: Path, monkeypatch) -> Iterator[PluginCreator]: + creator = PluginCreator(tmp_path, monkeypatch) + with creator: + yield creator + + +def test_find_good_plugin(plugin_creator: PluginCreator) -> None: + plugin_creator.add_plugin( + "submitit_good", + entry_points="""[submitit] +executor = submitit_good:GoodExecutor +job_environment = submitit_good:GoodJobEnvironment +unsupported_key = submitit_good:SomethingElse +""", + init=""" +import submitit + +class GoodExecutor(submitit.Executor): + pass + +class GoodJobEnvironment: + pass +""", + ) + + executors = plugins.get_executors().keys() + # Only the plugins declared with plugin_creator are visible. + assert set(executors) == {"good", "slurm", "local", "debug"} + + +def test_skip_bad_plugin(caplog, plugin_creator: PluginCreator) -> None: + caplog.set_level(logging.WARNING, logger="submitit") + plugin_creator.add_plugin( + "submitit_bad", + entry_points="""[submitit] +executor = submitit_bad:NonExisitingExecutor +job_environment = submitit_bad:BadEnvironment +unsupported_key = submitit_bad:SomethingElse +""", + init=""" +import submitit + +class BadEnvironment: + name = "bad" + + def __init__(self): + raise Exception("this is a bad environment") +""", + ) + + executors = plugins.get_executors().keys() + assert {"slurm", "local", "debug"} == set(executors) + assert "bad" not in executors + expected = [ + (logging.ERROR, r"'submitit_bad'.*no attribute 'NonExisitingExecutor'"), + (logging.ERROR, r"'submitit_bad'.*this is a bad environment"), + (logging.WARNING, "unsupported_key = submitit_bad:SomethingElse"), + ] + assert len(caplog.records) == len(expected) + for record, ex_record in zip(caplog.records, expected): + assert record.name == "submitit" + assert record.levelno == ex_record[0] + assert re.search(ex_record[1], record.getMessage()) diff --git a/src/submitit/core/test_utils.py b/src/submitit/core/test_utils.py new file mode 100644 index 0000000..c5799d9 --- /dev/null +++ b/src/submitit/core/test_utils.py @@ -0,0 +1,112 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import shutil +import sys +from pathlib import Path +from typing import Optional + +import pytest + +from . import utils + + +@pytest.mark.parametrize("existing_content", [None, "blublu"]) # type: ignore +def test_temporary_save_path(tmp_path: Path, existing_content: Optional[str]) -> None: + filepath = tmp_path / "save_and_move_test.txt" + if existing_content: + filepath.write_text(existing_content) + with utils.temporary_save_path(filepath) as tmp: + assert str(tmp).endswith(".txt.save_tmp") + tmp.write_text("12") + if existing_content: + assert filepath.read_text() == existing_content + assert filepath.read_text() == "12" + + +def test_temporary_save_path_error() -> None: + with pytest.raises(FileNotFoundError): + with utils.temporary_save_path("save_and_move_test"): + pass + + +def _three_time(x: int) -> int: + return 3 * x + + +def test_delayed(tmp_path: Path) -> None: + delayed = utils.DelayedSubmission(_three_time, 4) + assert not delayed.done() + assert delayed.result() == 12 + assert delayed.done() + delayed_pkl = tmp_path / "test_delayed.pkl" + delayed.dump(delayed_pkl) + delayed2 = utils.DelayedSubmission.load(delayed_pkl) + assert delayed2.done() + + +def test_environment_variable_context() -> None: + name = "ENV_VAR_TEST" + assert name not in os.environ + with utils.environment_variables(ENV_VAR_TEST="blublu"): + assert os.environ[name] == "blublu" + with utils.environment_variables(ENV_VAR_TEST="blublu2"): + assert os.environ[name] == "blublu2" + assert os.environ[name] == "blublu" + assert name not in os.environ + + +def test_slurmpaths_id_independent() -> None: + path = "test/truc/machin_%j/name" + output = utils.JobPaths.get_first_id_independent_folder(path) + assert output.name == "truc" + + +def test_archive_dev_folders(tmp_path: Path) -> None: + utils.archive_dev_folders([Path(__file__).parent], outfile=tmp_path.with_suffix(".tar.gz")) + shutil.unpack_archive(str(tmp_path.with_suffix(".tar.gz")), extract_dir=tmp_path) + assert (tmp_path / "core").exists() + + +def test_command_function() -> None: + # This will call `submitit.core.test_core.do_nothing` + command = [sys.executable, "-m", "submitit.core.test_core"] + word = "testblublu12" + output = utils.CommandFunction(command)(word) + assert output is not None + assert word in output + with pytest.raises(utils.FailedJobError, match="Too bad"): + # error=True will make `do_nothing` fail + utils.CommandFunction(command, verbose=True)(error=True) + + +def test_command_function_deadlock(executor) -> None: + code = """ +import sys; +print(sys.__stderr__) +# The goal here is to fill up the stderr pipe buffer. +for i in range({n}): + print("-" * 1024, file=sys.stdout) +print("printed {n} lines to stderr") +""" + fn1 = utils.CommandFunction([sys.executable, "-c", code.format(n=10)]) + executor.update_parameters(timeout_min=2 / 60) + j1 = executor.submit(fn1) + assert "10 lines" in j1.result() + + fn2 = utils.CommandFunction(["python", "-c", code.format(n=1000)]) + j2 = executor.submit(fn2) + assert "1000 lines" in j2.result() + + +def test_jobpaths(tmp_path: Path) -> None: + assert utils.JobPaths(tmp_path, "123").stdout == tmp_path / "123_0_log.out" + assert utils.JobPaths(tmp_path, "123", 1).stdout == tmp_path / "123_1_log.out" + assert ( + utils.JobPaths(tmp_path / "array-%A-index-%a", "456_3").stdout + == tmp_path / "array-456-index-3" / "456_3_0_log.out" + ) diff --git a/src/submitit/core/utils.py b/src/submitit/core/utils.py new file mode 100644 index 0000000..ab0bd01 --- /dev/null +++ b/src/submitit/core/utils.py @@ -0,0 +1,355 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import contextlib +import io +import itertools +import os +import pickle +import select +import shutil +import subprocess +import sys +import tarfile +import typing as tp +from pathlib import Path + +import cloudpickle + + +@contextlib.contextmanager +def environment_variables(**kwargs: tp.Any) -> tp.Iterator[None]: + backup = {x: os.environ[x] for x in kwargs if x in os.environ} + os.environ.update({x: str(y) for x, y in kwargs.items()}) + try: + yield + finally: + for x in kwargs: + del os.environ[x] + os.environ.update(backup) + + +class UncompletedJobError(RuntimeError): + """Job is uncomplete: either unfinished or failed""" + + +class FailedJobError(UncompletedJobError): + """Job failed during processing""" + + +class FailedSubmissionError(RuntimeError): + """Job Submission failed""" + + +class JobPaths: + """Creates paths related to the slurm job and its submission""" + + def __init__( + self, folder: tp.Union[Path, str], job_id: tp.Optional[str] = None, task_id: tp.Optional[int] = None + ) -> None: + self._folder = Path(folder).expanduser().absolute() + self.job_id = job_id + self.task_id = task_id or 0 + + @property + def folder(self) -> Path: + return self._format_id(self._folder) + + @property + def submission_file(self) -> Path: + if self.job_id and "_" in self.job_id: + # We only have one submission file per job array + return self._format_id(self.folder / "%A_submission.sh") + return self._format_id(self.folder / "%j_submission.sh") + + @property + def submitted_pickle(self) -> Path: + return self._format_id(self.folder / "%j_submitted.pkl") + + @property + def result_pickle(self) -> Path: + return self._format_id(self.folder / "%j_%t_result.pkl") + + @property + def stderr(self) -> Path: + return self._format_id(self.folder / "%j_%t_log.err") + + @property + def stdout(self) -> Path: + return self._format_id(self.folder / "%j_%t_log.out") + + def _format_id(self, path: tp.Union[Path, str]) -> Path: + """Replace id tag by actual id if available""" + if self.job_id is None: + return Path(path) + replaced_path = str(path).replace("%j", str(self.job_id)).replace("%t", str(self.task_id)) + array_id, *array_index = str(self.job_id).split("_", 1) + if "%a" in replaced_path: + if len(array_index) != 1: + raise ValueError("%a is in the folder path but this is not a job array") + replaced_path = replaced_path.replace("%a", array_index[0]) + return Path(replaced_path.replace("%A", array_id)) + + def move_temporary_file(self, tmp_path: tp.Union[Path, str], name: str) -> None: + self.folder.mkdir(parents=True, exist_ok=True) + Path(tmp_path).rename(getattr(self, name)) + + @staticmethod + def get_first_id_independent_folder(folder: tp.Union[Path, str]) -> Path: + """Returns the closest folder which is id independent""" + parts = Path(folder).expanduser().absolute().parts + tags = ["%j", "%t", "%A", "%a"] + indep_parts = itertools.takewhile(lambda x: not any(tag in x for tag in tags), parts) + return Path(*indep_parts) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.folder})" + + +class DelayedSubmission: + """Object for specifying the function/callable call to submit and process later. + This is only syntactic sugar to make sure everything is well formatted: + If what you want to compute later is func(*args, **kwargs), just instanciate: + DelayedSubmission(func, *args, **kwargs). + It also provides convenient tools for dumping and loading. + """ + + def __init__(self, function: tp.Callable[..., tp.Any], *args: tp.Any, **kwargs: tp.Any) -> None: + self.function = function + self.args = args + self.kwargs = kwargs + self._result: tp.Any = None + self._done = False + self._timeout_min: int = 0 + self._timeout_countdown: int = 0 # controlled in submission and execution + + def result(self) -> tp.Any: + if self._done: + return self._result + + self._result = self.function(*self.args, **self.kwargs) + self._done = True + return self._result + + def done(self) -> bool: + return self._done + + def dump(self, filepath: tp.Union[str, Path]) -> None: + cloudpickle_dump(self, filepath) + + def set_timeout(self, timeout_min: int, max_num_timeout: int) -> None: + self._timeout_min = timeout_min + self._timeout_countdown = max_num_timeout + + @classmethod + def load(cls: tp.Type["DelayedSubmission"], filepath: tp.Union[str, Path]) -> "DelayedSubmission": + obj = pickle_load(filepath) + # following assertion is relaxed compared to isinstance, to allow flexibility + # (Eg: copying this class in a project to be able to have checkpointable jobs without adding submitit as dependency) + assert obj.__class__.__name__ == cls.__name__, f"Loaded object is {type(obj)} but should be {cls}." + return obj # type: ignore + + def _checkpoint_function(self) -> tp.Optional["DelayedSubmission"]: + checkpoint = getattr(self.function, "__submitit_checkpoint__", None) + if checkpoint is None: + checkpoint = getattr(self.function, "checkpoint", None) + if checkpoint is None: + return None + return checkpoint(*self.args, **self.kwargs) # type: ignore + + +@contextlib.contextmanager +def temporary_save_path(filepath: tp.Union[Path, str]) -> tp.Iterator[Path]: + """Yields a path where to save a file and moves it + afterward to the provided location (and replaces any + existing file) + This is useful to avoid processes monitoring the filepath + to break if trying to read when the file is being written. + + Note + ---- + The temporary path is the provided path appended with .save_tmp + """ + filepath = Path(filepath) + tmppath = filepath.with_suffix(filepath.suffix + ".save_tmp") + assert not tmppath.exists(), "A temporary saved file already exists." + yield tmppath + if not tmppath.exists(): + raise FileNotFoundError("No file was saved at the temporary path.") + if filepath.exists(): + os.remove(filepath) + os.rename(tmppath, filepath) + + +def archive_dev_folders( + folders: tp.List[tp.Union[str, Path]], outfile: tp.Optional[tp.Union[str, Path]] = None +) -> Path: + """Creates a tar.gz file with all provided folders""" + assert isinstance(folders, (list, tuple)), "Only lists and tuples of folders are allowed" + if outfile is None: + outfile = "_dev_folders_.tar.gz" + outfile = Path(outfile) + assert str(outfile).endswith(".tar.gz"), "Archive file must have extension .tar.gz" + with tarfile.TarFile(outfile, mode="w") as tf: + for folder in folders: + tf.add(str(folder), arcname=Path(folder).name) + return outfile + + +def copy_par_file(par_file: tp.Union[str, Path], folder: tp.Union[str, Path]) -> Path: + """Copy the par (or xar) file in the folder + + Parameter + --------- + par_file: str/Path + Par file generated by buck + folder: str/Path + folder where the par file must be copied + + Returns + ------- + Path + Path of the copied .par file + """ + par_file = Path(par_file).expanduser().absolute() + folder = Path(folder).expanduser().absolute() + folder.mkdir(parents=True, exist_ok=True) + dst_name = folder / par_file.name + shutil.copy2(par_file, dst_name) + return dst_name + + +def pickle_load(filename: tp.Union[str, Path]) -> tp.Any: + # this is used by cloudpickle as well + with open(filename, "rb") as ifile: + return pickle.load(ifile) + + +def cloudpickle_dump(obj: tp.Any, filename: tp.Union[str, Path]) -> None: + with open(filename, "wb") as ofile: + cloudpickle.dump(obj, ofile, pickle.HIGHEST_PROTOCOL) + + +# pylint: disable=too-many-locals +def copy_process_streams( + process: subprocess.Popen, stdout: io.StringIO, stderr: io.StringIO, verbose: bool = False +): + """ + Reads the given process stdout/stderr and write them to StringIO objects. + Make sure that there is no deadlock because of pipe congestion. + If `verbose` the process stdout/stderr are also copying to the interpreter stdout/stderr. + """ + + def raw(stream: tp.Optional[tp.IO[bytes]]) -> tp.IO[bytes]: + assert stream is not None + if isinstance(stream, io.BufferedIOBase): + stream = stream.raw + return stream + + p_stdout, p_stderr = raw(process.stdout), raw(process.stderr) + stream_by_fd: tp.Dict[int, tp.Tuple[tp.IO[bytes], io.StringIO, tp.IO[str]]] = { + p_stdout.fileno(): (p_stdout, stdout, sys.stdout), + p_stderr.fileno(): (p_stderr, stderr, sys.stderr), + } + fds = list(stream_by_fd.keys()) + poller = select.poll() + for fd in stream_by_fd: + poller.register(fd, select.POLLIN | select.POLLPRI) + while fds: + # `poll` syscall will wait until one of the registered file descriptors has content. + ready = poller.poll() + for fd, _ in ready: + p_stream, string, std = stream_by_fd[fd] + raw_buf = p_stream.read(2**16) + if not raw_buf: + fds.remove(fd) + poller.unregister(fd) + continue + buf = raw_buf.decode() + string.write(buf) + string.flush() + if verbose: + std.write(buf) + std.flush() + + +# used in "_core", so cannot be in "helpers" +class CommandFunction: + """Wraps a command as a function in order to make sure it goes through the + pipeline and notify when it is finished. + The output is a string containing everything that has been sent to stdout. + WARNING: use CommandFunction only if you know the output won't be too big ! + Otherwise use subprocess.run() that also streams the outputto stdout/stderr. + + Parameters + ---------- + command: list + command to run, as a list + verbose: bool + prints the command and stdout at runtime + cwd: Path/str + path to the location where the command must run from + + Returns + ------- + str + Everything that has been sent to stdout + """ + + def __init__( + self, + command: tp.List[str], + verbose: bool = True, + cwd: tp.Optional[tp.Union[str, Path]] = None, + env: tp.Optional[tp.Dict[str, str]] = None, + ) -> None: + if not isinstance(command, list): + raise TypeError("The command must be provided as a list") + self.command = command + self.verbose = verbose + self.cwd = None if cwd is None else str(cwd) + self.env = env + + def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> str: + """Call the cammand line with addidional arguments + The keyword arguments will be sent as --{key}={val} + The logs bufferized. They will be printed if the job fails, or sent as output of the function + Errors are provided with the internal stderr. + """ + full_command = ( + self.command + [str(x) for x in args] + [f"--{x}={y}" for x, y in kwargs.items()] + ) # TODO bad parsing + if self.verbose: + print(f"The following command is sent: \"{' '.join(full_command)}\"") + with subprocess.Popen( + full_command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + cwd=self.cwd, + env=self.env, + ) as process: + stdout_buffer = io.StringIO() + stderr_buffer = io.StringIO() + + try: + copy_process_streams(process, stdout_buffer, stderr_buffer, self.verbose) + except Exception as e: + process.kill() + process.wait() + raise FailedJobError("Job got killed for an unknown reason.") from e + stdout = stdout_buffer.getvalue().strip() + stderr = stderr_buffer.getvalue().strip() + retcode = process.wait() + if stderr and (retcode and not self.verbose): + # We don't print is self.verbose, as it already happened before. + print(stderr, file=sys.stderr) + if retcode: + subprocess_error = subprocess.CalledProcessError( + retcode, process.args, output=stdout, stderr=stderr + ) + raise FailedJobError(stderr) from subprocess_error + return stdout diff --git a/src/submitit/helpers.py b/src/submitit/helpers.py new file mode 100644 index 0000000..e5b966a --- /dev/null +++ b/src/submitit/helpers.py @@ -0,0 +1,446 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import collections +import contextlib +import datetime +import itertools +import os +import random +import shutil +import subprocess +import tempfile +import time +import typing as tp +from pathlib import Path + +# pylint: disable=unused-import +# import DelayedSubmission and CommandFunction to populate helpers namespace +from .core import core +from .core.job_environment import JobEnvironment +from .core.utils import CommandFunction as CommandFunction # noqa +from .core.utils import DelayedSubmission as DelayedSubmission # noqa +from .core.utils import environment_variables as environment_variables # noqa + + +class Checkpointable: + """Derived callable classes are requeued after timeout with their current + state dumped at checkpoint. + + __call__ method must be implemented to make your class a callable. + + Note + ---- + The following implementation of the checkpoint method resubmits the full current + state of the callable (self) with the initial argument. You may want to replace the method to + curate the state (dump a neural network to a standard format and remove it from + the state so that not to pickle it) and change/remove the initial parameters. + """ + + # pylint: disable=unused-argument + def __new__(cls, *args, **kwargs): + instance = super().__new__(cls) + assert callable( + instance + ), f"Class {cls.__name__} is marked as Checkpointable but doesn't have a __call__ method. Please add a __call__ method." + return instance + + def checkpoint(self, *args: tp.Any, **kwargs: tp.Any) -> DelayedSubmission: + """Resubmits the same callable with the same arguments""" + # The DelayedSubmission class goal is only to register and format + # the arguments of the call "self(*args, **kwargs)" for submission to slurm + return DelayedSubmission(self, *args, **kwargs) # type: ignore + + +class FunctionSequence(Checkpointable): + """This is for gathering several estimations into one function, which + will return the sequence of outputs. + Also this "function" is stateful, hence it can be stopped, and recovered, + which is useful when job can be preempted. + + Usage + ----- + func = FunctionSequence() + func.add(my_function1, arg1, kwarg1=value_kwarg1) + func.add(my_function2, arg1, arg2) + result1, result2 = func() + + Note + ---- + This function is checkpointable because: + - it derives from Checkpointable + - it keeps DelayedSubmission objects as attribute, which in turn store the + results of the computation in memory once they are computed. So at checkpoint + time, those results will be saved, and only the non-computed results + will be computed once the job restarts. + """ + + def __init__(self, verbose: bool = False) -> None: + self.verbose = verbose + self.delayed_functions: tp.List[DelayedSubmission] = [] + + def add( + self, func: tp.Callable[..., tp.Any], *args: tp.Any, **kwargs: tp.Any + ) -> None: + self.delayed_functions.append(DelayedSubmission(func, *args, **kwargs)) + + def __len__(self) -> int: + return len(self.delayed_functions) + + def __iter__(self) -> tp.Iterator[DelayedSubmission]: + return iter(self.delayed_functions) + + def __call__(self) -> tp.List[tp.Any]: # pylint: disable=arguments-differ + if self.verbose: + done = sum(f.done() for f in self) # those were computed before checkpoint + print(f"Starting from {done}/{len(self.delayed_functions)}", flush=True) + return [ + f.result() for f in self.delayed_functions + ] # results all results one by one (by running the functions if not already done) + + +def as_completed( + jobs: tp.Sequence[core.Job[core.R]], + timeout: tp.Optional[tp.Union[int, float]] = None, + poll_frequency: float = 10, +) -> tp.Iterator[core.Job[core.R]]: + """ + Yields jobs as they complete (finished, failed or were cancelled). + Raises a TimeoutError if the result isn’t available after timeout seconds. + timeout can be an int or float. If timeout is not specified or None, there is no + limit to the wait time. + + Parameters + ---------- + jobs: list + Jobs instances + + timeout: int/float + Maximum time (in sec) to wait for jobs completion + + poll_frequency: float + Frequency in second at which we check job status. + + Yields + ------ + Job + The next completed job + """ + start = time.time() + jobs_done: tp.Set[int] = set() + while True: + if timeout is not None and time.time() - start > timeout: + raise TimeoutError + for i, job in enumerate(jobs): + if i in jobs_done: + continue + if job.done(): + jobs_done.add(i) + yield job + if len(jobs_done) == len(jobs): + break + time.sleep(poll_frequency) + + +def run_cmd(str_args, **kwargs): + return subprocess.check_output(str_args, **kwargs).decode("utf-8").strip() + + +class RsyncSnapshot: + """Takes a snapshot of the git repository that the script lives in. + + This ensures that remote jobs always use the code from when they are scheduled + and not the code from when they are launched / re-started. + + + Parameters + ---------- + snapshot_dir: Path + A path to where the snapshot should be created + with_submodules: bool + Whether or not submodules should be included in the snapshot + exclude: Sequence[str] + An optional list of patterns to exclude from the snapshot + include: Sequence[str] + A list of relative file names to include from the snapshot. + Useful for .so or other build artifacts that are genarally not tracked by git. + + Note + ---- + - Only files that are checked in to the repository are included in the snapshot. + If you have experimental code that you would like to include in the snapshot, + you'll need to `git add` the file first for it to be included, or use `include` arg. + """ + + def __init__( + self, + snapshot_dir: Path, + root_dir: Path = None, + with_submodules: bool = False, + exclude: tp.Sequence[str] = (), + include: tp.Sequence[str] = (), + ): + self.available(throw=True) + self.snapshot_dir = Path(snapshot_dir) + self.root_dir = root_dir or run_cmd(["git", "rev-parse", "--show-toplevel"]) + self.original_dir = Path.cwd() + self.with_submodules = with_submodules + self.exclude = exclude + self.include = include + + @staticmethod + def available(throw: bool = False) -> bool: + if not shutil.which("rsync"): + if throw: + raise RuntimeError("RsyncSnapshot requires rsync to be installed.") + return False + return True + + def __enter__(self) -> None: + self.original_dir = Path.cwd() + # Get the repository root + root_dir = str(self.root_dir) + sub = "--recurse-submodules" if self.with_submodules else "-s" + # Make a shallow git clone + if not self.snapshot_dir.exists(): + self.snapshot_dir.parent.mkdir(parents=True, exist_ok=True) + subprocess.check_call( + [ + "git", + "clone", + "--depth=2", + f"file://{root_dir}", + str(self.snapshot_dir), + ] + ) + + # Get a list of all the checked in files that we can pass to rsync + # Is Rsync faster than a `git pull` ? + with tempfile.NamedTemporaryFile() as tfile: + # https://stackoverflow.com/a/51689219/4876946 + run_cmd( + f"git ls-files {sub} | grep -v ^16 | cut -f2- > {tfile.name}", + cwd=root_dir, + shell=True, + ) + exclude = list( + itertools.chain.from_iterable( + ("--exclude", pat) for pat in self.exclude + ) + ) + with open(tfile.name, "a", encoding="utf8") as o: + for inc in self.include: + print(inc, file=o) + run_cmd( + [ + "rsync", + "-a", + "--files-from", + tfile.name, + root_dir, + str(self.snapshot_dir), + ] + + exclude + ) + os.chdir(self.snapshot_dir) + + def __exit__(self, *args): + os.chdir(self.original_dir) + + +def _default_custom_logging( + monitoring_start_time: float, n_jobs: int, state_jobs: tp.Dict[str, tp.Set[int]] +): + run_time = time.time() - monitoring_start_time + date_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + failed_job_indices = sorted(state_jobs["FAILED"]) + n_chars = len(str(n_jobs)) + + print( + f"[{date_time}] Launched {int(run_time / 60)} minutes ago,", + f"{len(state_jobs['RUNNING']):{n_chars}}/{n_jobs} jobs running,", + f"{len(failed_job_indices):{n_chars}}/{n_jobs} jobs failed,", + f"{len(state_jobs['DONE']) - len(failed_job_indices):{n_chars}}/{n_jobs} jobs done", + flush=True, + ) + + if len(failed_job_indices) > 0: + print(f"[{date_time}] Failed jobs, indices {failed_job_indices}", flush=True) + + +def monitor_jobs( + jobs: tp.Sequence[core.Job[core.R]], + poll_frequency: float = 30, + test_mode: bool = False, + custom_logging: tp.Callable = _default_custom_logging, +) -> None: + """Continuously monitors given jobs until they are all done or failed. + + Parameters + ---------- + jobs: List[Jobs] + A list of jobs to monitor + poll_frequency: int + The time (in seconds) between two refreshes of the monitoring. + Can't be inferior to 30s. + test_mode: bool + If in test mode, we do not check the length of poll_frequency + """ + + if not test_mode: + assert ( + poll_frequency >= 30 + ), "You can't refresh too often (>= 30s) to avoid overloading squeue" + + n_jobs = len(jobs) + if n_jobs == 0: + print("There are no jobs to monitor") + return + + job_arrays = ", ".join( + sorted(set(str(job.job_id).split("_", 1)[0] for job in jobs)) + ) + print(f"Monitoring {n_jobs} jobs from job arrays {job_arrays} \n") + + monitoring_start_time = time.time() + while True: + if not test_mode: + jobs[0].get_info(mode="force") # Force update once to sync the state + state_jobs = collections.defaultdict(set) + for i, job in enumerate(jobs): + state_jobs[job.state.upper()].add(i) + if job.done(): + state_jobs["DONE"].add(i) + + failed_job_indices = sorted(state_jobs["FAILED"]) + if len(state_jobs["DONE"]) == len(jobs): + print( + f"All jobs finished, jobs with indices {failed_job_indices} failed", + flush=True, + ) + break + + custom_logging(monitoring_start_time, n_jobs, state_jobs) + time.sleep(poll_frequency) + + print( + f"Whole process is finished, took {int((time.time() - monitoring_start_time) / 60)} minutes" + ) + + +@contextlib.contextmanager +def clean_env() -> tp.Iterator[None]: + """Removes slurm and submitit related environment variables so as to avoid interferences + when submiting a new job from a job. + + Note + ---- + A slurm job submitted from within a slurm job inherits some of its attributes, which may + be confusing a cause weird gres errors (or pytorch distributed). + Submitting within this context should prevent this. + + Usage + ----- + with submitit.helpers.clean_env(): + executor.submit(...) + """ + distrib_names = ( + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", + ) + cluster_env = { + x: os.environ.pop(x) + for x in os.environ + if x.startswith(("SLURM_", "SUBMITIT_")) or x in distrib_names + } + try: + yield + finally: + os.environ.update(cluster_env) + + +class TorchDistributedEnvironment: + def __init__(self) -> None: + """Construct a class holding the parameters required to properly setup + PyTorch distributed (with the default env:// initialization method). + + Examples + -------- + >>> dist_env = TorchDistributedEnvironment().export() + >>> torch.distributed.init_process_group(backend="nccl") + >>> print(f"master: {dist_env.master_addr}:{dist_env.master_port}") + """ + self._job_env = JobEnvironment() + self.master_addr = self._job_env.hostnames[0] + self.master_port = self._get_master_port() + self.rank = self._job_env.global_rank + self.world_size = self._job_env.num_tasks + self.local_rank = self._job_env.local_rank + self.local_world_size = self._job_env.num_tasks // self._job_env.num_nodes + + def _get_master_port(self) -> int: + # MIN_MASTER_PORT, MAX_MASTER_PORT = (1023, 65535) + MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000) + + master_port_str = os.environ.get("MASTER_PORT") + if master_port_str is None: + rng = random.Random(self._job_env.job_id) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + master_port = int(master_port_str) + # assert MIN_MASTER_PORT <= master_port <= MIN_MASTER_PORT + return master_port + + def export( + self, + set_cuda_visible_devices: bool = True, + overwrite: bool = False, + ) -> "TorchDistributedEnvironment": + """Export all the environment variables required to properly setup + PyTorch distributed (with the default env:// initialization method) i.e. + MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE (to which LOCAL_RANK and + LOCAL_WORLD_SIZE are added). + + Parameter + ---------- + set_cuda_visible_device: bool + if True, updates CUDA_VISIBLE_DEVICES to use only the device + matching the local rank. + overwrite: bool + if True, overwrites the environment variables if they exist; + this can be useful when launching a job from another job. + + Returns + -------- + TorchDistributedEnvironment + the current instance + """ + # See the "Environment variable initialization" section from + # https://pytorch.org/docs/stable/distributed.html for the complete list of + # environment variables required for the env:// initialization method. + env_vars = { + "MASTER_ADDR": self.master_addr, + "MASTER_PORT": str(self.master_port), + "RANK": str(self.rank), + "WORLD_SIZE": str(self.world_size), + "LOCAL_RANK": str(self.local_rank), # Not required + "LOCAL_WORLD_SIZE": str(self.local_world_size), # Not required + } + if not overwrite: + for key in env_vars: + if key in os.environ: + raise RuntimeError( + f"Cannot export environment variables as {key} is already set" + ) + # Note: CUDA_VISIBLE_DEVICES may already be set with all available GPUs + if set_cuda_visible_devices: + env_vars["CUDA_VISIBLE_DEVICES"] = str(self.local_rank) + os.environ.update(env_vars) + return self diff --git a/src/submitit/local/__init__.py b/src/submitit/local/__init__.py new file mode 100644 index 0000000..602d268 --- /dev/null +++ b/src/submitit/local/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# diff --git a/src/submitit/local/_local.py b/src/submitit/local/_local.py new file mode 100644 index 0000000..9b1125a --- /dev/null +++ b/src/submitit/local/_local.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import sys +from pathlib import Path + +from .local import Controller + +if __name__ == "__main__": + assert len(sys.argv) == 2, "Usage: _local.py " + # most arguments are read from environment variables. + controller = Controller(Path(sys.argv[1])) + controller.run() diff --git a/src/submitit/local/debug.py b/src/submitit/local/debug.py new file mode 100644 index 0000000..45e5c79 --- /dev/null +++ b/src/submitit/local/debug.py @@ -0,0 +1,155 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +import os +import typing as tp +from pathlib import Path +from typing import Dict, List, Optional, Union + +from ..core.core import Executor, InfoWatcher, Job, R +from ..core.job_environment import JobEnvironment +from ..core.utils import DelayedSubmission, UncompletedJobError + + +class DebugInfoWatcher(InfoWatcher): + # pylint: disable=abstract-method + def register_job(self, job_id: str) -> None: + pass + + +class DebugJobEnvironment(JobEnvironment): + _env = { + "job_id": "SUBMITIT_DEBUG_JOB_ID", + # We don't set those, and rely on the default values from JobEnvironment + "num_nodes": "SUBMITIT_DEBUG_NOT_SET", + "num_tasks": "SUBMITIT_DEBUG_NOT_SET", + "node": "SUBMITIT_DEBUG_NOT_SET", + "global_rank": "SUBMITIT_DEBUG_NOT_SET", + "local_rank": "SUBMITIT_DEBUG_NOT_SET", + } + + def activated(self) -> bool: + return "SUBMITIT_DEBUG_JOB_ID" in os.environ + + def _requeue(self, countdown: int) -> None: + pass + + +# pylint in python 3.6 is confused by generics. +class DebugJob(Job[R]): + watcher = DebugInfoWatcher() + + def __init__(self, folder: Path, submission: DelayedSubmission) -> None: + job_id = f"DEBUG_{id(submission)}" + super().__init__(folder=folder, job_id=job_id) + self._submission = submission + self.cancelled = False + self.environ = dict(os.environ) + self.environ["SUBMITIT_DEBUG_JOB_ID"] = self.job_id + + def submission(self) -> DelayedSubmission: + return self._submission + + @property + def num_tasks(self) -> int: + return 1 + + def cancel(self, check: bool = True) -> None: # pylint: disable=unused-argument + self.cancelled = True + + def _check_not_cancelled(self) -> None: + if self.cancelled: + raise UncompletedJobError(f"Job {self} was cancelled.") + + def results(self) -> List[R]: + self._check_not_cancelled() + if self._submission.done(): + return [self._submission._result] + + environ_backup = dict(os.environ) + # Restore os.environ from job creation time. + os.environ.clear() + os.environ.update(self.environ) + + root_logger = logging.getLogger("") + self.paths.stdout.parent.mkdir(exist_ok=True, parents=True) + stdout_handler = logging.FileHandler(self.paths.stdout) + stdout_handler.setLevel(logging.DEBUG) + stderr_handler = logging.FileHandler(self.paths.stderr) + stderr_handler.setLevel(logging.WARNING) + root_logger.addHandler(stdout_handler) + root_logger.addHandler(stderr_handler) + root_logger.warning( + f"Logging is written both to stderr/stdout and to {self.paths.stdout}/err. " + "But call to print will only appear in the console." + ) + try: + return [self._submission.result()] + except Exception as e: + print(e) + # Try to mimic `breakpoint()` behavior + # pylint: disable=import-outside-toplevel + if os.environ.get("PYTHONBREAKPOINT", "").startswith("ipdb"): + import ipdb # pylint: disable=import-error + + ipdb.post_mortem() + else: + import pdb + + pdb.post_mortem() + raise + finally: + os.environ.clear() + os.environ.update(environ_backup) + root_logger.removeHandler(stdout_handler) + root_logger.removeHandler(stderr_handler) + + def exception(self) -> Optional[BaseException]: # type: ignore + self._check_not_cancelled() + try: + self._submission.result() + return None + except Exception as e: + # Note that we aren't wrapping the error contrary to what is done in + # other Executors. It makes the stacktrace smaller and debugging easier. + return e + + def wait(self) -> None: + # forces execution. + self.results() + + def done(self, force_check: bool = False) -> bool: # pylint: disable=unused-argument + # forces execution, in case the client is waiting on it to become True. + self.results() + return self._submission.done() + + @property + def state(self) -> str: + if self._submission.done(): + return "DONE" + if self.cancelled: + return "CANCELLED" + return "QUEUED" + + def get_info(self, mode: str = "force") -> Dict[str, str]: # pylint: disable=unused-argument + return {"STATE": self.state} + + def __del__(self) -> None: + # Skip parent code + return + + +class DebugExecutor(Executor): + job_class = DebugJob + + def __init__(self, folder: Union[str, Path]): + super().__init__(folder) + + def _internal_process_submissions( + self, delayed_submissions: tp.List[DelayedSubmission] + ) -> tp.List[Job[tp.Any]]: + return [DebugJob(self.folder, ds) for ds in delayed_submissions] diff --git a/src/submitit/local/local.py b/src/submitit/local/local.py new file mode 100644 index 0000000..78f8a1a --- /dev/null +++ b/src/submitit/local/local.py @@ -0,0 +1,359 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import shlex +import signal +import subprocess +import sys +import time +from pathlib import Path +from typing import IO, Any, Dict, List, Optional, Sequence, Union + +from ..core import core, job_environment, logger, utils +from ..core.core import R + +# pylint: disable-msg=too-many-arguments +VALID_KEYS = { + "timeout_min", + "gpus_per_node", + "tasks_per_node", + "signal_delay_s", + "visible_gpus", +} + +LOCAL_REQUEUE_RETURN_CODE = 144 + + +class LocalJob(core.Job[R]): + def __init__( + self, + folder: Union[Path, str], + job_id: str, + tasks: Sequence[int] = (0,), + process: Optional["subprocess.Popen['bytes']"] = None, + ) -> None: + super().__init__(folder, job_id, tasks) + self._cancel_at_deletion = False + self._process = process + # downcast sub-jobs to get proper typing + self._sub_jobs: Sequence["LocalJob[R]"] = self._sub_jobs + for sjob in self._sub_jobs: + sjob._process = process + + def done(self, force_check: bool = False) -> bool: # pylint: disable=unused-argument + """Override to avoid using the watcher""" + assert self._process is not None + return self._process.poll() is not None + + @property + def state(self) -> str: + """State of the job""" + try: + return self.get_info().get("jobState", "unknown") + # I don't what is the exception returned and it's hard to reproduce + except Exception: + return "UNKNOWN" + + def get_info(self, mode: str = "force") -> Dict[str, str]: # pylint: disable=unused-argument + """Returns information about the job as a dict.""" + assert self._process is not None + poll = self._process.poll() + if poll is None: + state = "RUNNING" + elif poll < 0: + state = "INTERRUPTED" + else: + state = "FINISHED" + return {"jobState": state} + + def cancel(self, check: bool = True) -> None: # pylint: disable=unused-argument + assert self._process is not None + self._process.send_signal(signal.SIGINT) + + def _interrupt(self) -> None: + """Sends preemption / timeout signal to the job (for testing purpose)""" + assert self._process is not None + self._process.send_signal(LocalJobEnvironment._usr_sig()) + + def __del__(self) -> None: + if self._cancel_at_deletion: + if not self.get_info().get("jobState") == "FINISHED": + self.cancel(check=False) + + +class LocalJobEnvironment(job_environment.JobEnvironment): + _env = { + "job_id": "SUBMITIT_LOCAL_JOB_ID", + "num_tasks": "SUBMITIT_LOCAL_NTASKS", + "num_nodes": "SUBMITIT_LOCAL_JOB_NUM_NODES", + "node": "SUBMITIT_LOCAL_NODEID", + "global_rank": "SUBMITIT_LOCAL_GLOBALID", + "local_rank": "SUBMITIT_LOCAL_LOCALID", + } + + def _requeue(self, countdown: int) -> None: + jid = self.job_id + logger.get_logger().info(f"Requeued job {jid} ({countdown} remaining timeouts)") + sys.exit(LOCAL_REQUEUE_RETURN_CODE) # should help noticing if need requeuing + + +class LocalExecutor(core.PicklingExecutor): + """Local job executor + This class is used to hold the parameters to run a job locally. + In practice, it will create a bash file in the specified directory for each job, + and pickle the task function and parameters. At completion, the job will also pickle + the output. Logs are also dumped in the same directory. + + The submission file spawn several processes (one per task), with a timeout. + + + Parameters + ---------- + folder: Path/str + folder for storing job submission/output and logs. + + Note + ---- + - be aware that the log/output folder will be full of logs and pickled objects very fast, + it may need cleaning. + - use update_parameters to specify custom parameters (n_gpus etc...). + """ + + job_class = LocalJob + + def __init__(self, folder: Union[str, Path], max_num_timeout: int = 3) -> None: + super().__init__(folder, max_num_timeout=max_num_timeout) + # preliminary check + indep_folder = utils.JobPaths.get_first_id_independent_folder(self.folder) + indep_folder.mkdir(parents=True, exist_ok=True) + + def _internal_update_parameters(self, **kwargs: Any) -> None: + """Update the parameters of the Executor. + + Valid parameters are: + - timeout_min (float) + - gpus_per_node (int) + - visible_gpus (Sequence[int]) + - tasks_per_node (int) + - nodes (int). Must be 1 if specified + - signal_delay_s (int): signal (lately: USR2) delay before timeout + + Other parameters are ignored + """ + if kwargs.get("nodes", 0) > 1: + raise ValueError("LocalExecutor can use only one node. Use nodes=1") + gpus_requested = kwargs.get("gpus_per_node", 0) + visible_gpus = kwargs.get("visible_gpus", ()) + if not isinstance(visible_gpus, Sequence): + raise ValueError( + f"Provided visible_gpus={visible_gpus} is not an instance of Sequence." + ) + if not all(isinstance(x, int) for x in visible_gpus): + raise ValueError( + f"Provided visible_gpus={visible_gpus} contains an element that is not an int." + ) + if len(visible_gpus) > 0 and gpus_requested > len(visible_gpus): + raise ValueError( + f"{gpus_requested} gpus requested, but only {visible_gpus} were specified visible." + ) + super()._internal_update_parameters(**kwargs) + + def _submit_command(self, command: str) -> LocalJob[R]: + # Override this, because the implementation is simpler than for clusters like Slurm + # Only one node is supported for local executor. + ntasks = self.parameters.get("tasks_per_node", 1) + n_gpus = self.parameters.get("gpus_per_node", 0) + visible_gpus = self.parameters.get("visible_gpus", ()) + gpus = range(n_gpus) if visible_gpus == () else visible_gpus[:n_gpus] + process = start_controller( + folder=self.folder, + command=command, + tasks_per_node=ntasks, + cuda_devices=",".join(str(k) for k in gpus), + timeout_min=self.parameters.get("timeout_min", 2.0), + signal_delay_s=self.parameters.get("signal_delay_s", 30), + stderr_to_stdout=self.parameters.get("stderr_to_stdout", False), + ) + job: LocalJob[R] = LocalJob( + folder=self.folder, + job_id=str(process.pid), + process=process, + tasks=list(range(ntasks)), + ) + return job + + @property + def _submitit_command_str(self) -> str: + return " ".join( + [ + shlex.quote(sys.executable), + "-u -m submitit.core._submit", + shlex.quote(str(self.folder)), + ] + ) + + def _num_tasks(self) -> int: + nodes: int = 1 + tasks_per_node: int = self.parameters.get("tasks_per_node", 1) + return nodes * tasks_per_node + + def _make_submission_file_text(self, command: str, uid: str) -> str: + return "" + + @staticmethod + def _get_job_id_from_submission_command(string: Union[bytes, str]) -> str: + # Not used, but need an implementation + return "0" + + def _make_submission_command(self, submission_file_path: Path) -> List[str]: + # Not used, but need an implementation + return [] + + +def start_controller( + folder: Path, + command: str, + tasks_per_node: int = 1, + cuda_devices: str = "", + timeout_min: float = 5.0, + signal_delay_s: int = 30, + stderr_to_stdout: bool = False, +) -> "subprocess.Popen['bytes']": + """Starts a job controller, which is expected to survive the end of the python session.""" + env = dict(os.environ) + env.update( + SUBMITIT_LOCAL_NTASKS=str(tasks_per_node), + SUBMITIT_LOCAL_COMMAND=command, + SUBMITIT_LOCAL_TIMEOUT_S=str(int(60 * timeout_min)), + SUBMITIT_LOCAL_SIGNAL_DELAY_S=str(int(signal_delay_s)), + SUBMITIT_LOCAL_NODEID="0", + SUBMITIT_LOCAL_JOB_NUM_NODES="1", + SUBMITIT_STDERR_TO_STDOUT="1" if stderr_to_stdout else "", + SUBMITIT_EXECUTOR="local", + CUDA_VISIBLE_DEVICES=cuda_devices, + ) + # The LocalJob will be responsible to polling and ending this process. + # pylint: disable=consider-using-with + process = subprocess.Popen( + [sys.executable, "-m", "submitit.local._local", str(folder)], + shell=False, + env=env, + ) + return process + + +class Controller: + """This controls a job: + - instantiate each of the tasks + - sends timeout signal + - stops all tasks if one of them finishes + - cleans up the tasks/closes log files when deleted + """ + + # pylint: disable=too-many-instance-attributes + + def __init__(self, folder: Path): + self.ntasks = int(os.environ["SUBMITIT_LOCAL_NTASKS"]) + self.command = shlex.split(os.environ["SUBMITIT_LOCAL_COMMAND"]) + self.timeout_s = int(os.environ["SUBMITIT_LOCAL_TIMEOUT_S"]) + self.signal_delay_s = int(os.environ["SUBMITIT_LOCAL_SIGNAL_DELAY_S"]) + self.stderr_to_stdout = bool(os.environ["SUBMITIT_STDERR_TO_STDOUT"]) + self.tasks: List[subprocess.Popen] = [] # type: ignore + self.stdouts: List[IO[Any]] = [] + self.stderrs: List[IO[Any]] = [] + self.pid = str(os.getpid()) + self.folder = Path(folder) + signal.signal(signal.SIGTERM, self._forward_signal) # type: ignore + + def _forward_signal(self, signum: signal.Signals, *args: Any) -> None: # pylint:disable=unused-argument + for task in self.tasks: + try: + task.send_signal( + signum + ) # sending kill signal to make sure everything finishes + except Exception: + pass + + def start_tasks(self) -> None: + self.folder.mkdir(exist_ok=True) + paths = [utils.JobPaths(self.folder, self.pid, k) for k in range(self.ntasks)] + self.stdouts = [p.stdout.open("a") for p in paths] + self.stderrs = ( + self.stdouts + if self.stderr_to_stdout + else [p.stderr.open("a") for p in paths] + ) + for k in range(self.ntasks): + env = dict(os.environ) + env.update( + SUBMITIT_LOCAL_LOCALID=str(k), + SUBMITIT_LOCAL_GLOBALID=str(k), + SUBMITIT_LOCAL_JOB_ID=self.pid, + ) + self.tasks.append( + subprocess.Popen( # pylint: disable=consider-using-with + self.command, + shell=False, + env=env, + stderr=self.stderrs[k], + stdout=self.stdouts[k], + encoding="utf-8", + ) + ) + + def kill_tasks(self) -> None: + # try and be progressive in deletion... + for sig in [signal.SIGINT, signal.SIGKILL]: + self._forward_signal(sig) + # if one is still alive after sigterm and sigint, try sigkill after 1s + if sig == signal.SIGINT and any(t.poll() is None for t in self.tasks): + time.sleep(0.001) + if any(t.poll() is None for t in self.tasks): + time.sleep(1.0) # wait a bit more + self.tasks = [] + files = self.stdouts + self.stderrs + self.stdouts, self.stderrs = [], [] # remove all instance references + for f in files: + f.close() + + def wait(self, freq: int = 24) -> Sequence[Optional[int]]: + """Waits for all tasks to finish or to time-out. + + Returns + ------- + Sequence[Optional[int]]: + Exit codes of each task. + Some tasks might still have not exited, but they will have received the "timed-out" signal. + """ + assert self.tasks, "Nothing to do!" + timeout = freq * self.timeout_s + almost_timeout = freq * (self.timeout_s - self.signal_delay_s) + + # safer to keep a for loop :) + for step in range(timeout): + exit_codes = [t.poll() for t in self.tasks] + if all(e is not None for e in exit_codes): + return exit_codes + + if step == almost_timeout: + self._forward_signal(LocalJobEnvironment._usr_sig()) + + time.sleep(1.0 / freq) + return [t.poll() for t in self.tasks] + + def run(self, max_retry: int = 6) -> None: + # max_retry is a safety measure, the submission also have a timeout_countdown, + # and will fail if it times out too many times. + for _ in range(max_retry): + try: + self.start_tasks() + exit_codes = self.wait() + requeue = any(e == LOCAL_REQUEUE_RETURN_CODE for e in exit_codes) + if not requeue: + break + finally: + self.kill_tasks() diff --git a/src/submitit/local/test_debug.py b/src/submitit/local/test_debug.py new file mode 100644 index 0000000..bc7c330 --- /dev/null +++ b/src/submitit/local/test_debug.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import functools +import os +from pathlib import Path +from typing import Any, Tuple + +import pytest + +from ..core import utils +from ..core.core import Job +from ..core.job_environment import JobEnvironment +from . import debug + + +class CheckFunction: + """Function used for checking that computations are correct""" + + def __init__(self, n: int) -> None: + self.data1 = list(range(n)) + self.data2 = list(range(10, 10 + n)) + + def __call__(self, x: float, y: float) -> float: + assert x in self.data1 + assert y in self.data2 + return x + y + + +def test_debug_job(tmp_path: Path) -> None: + def func(p: int) -> int: + return 2 * p + + executor = debug.DebugExecutor(tmp_path) + job = executor.submit(func, 4) + assert job.result() == 8 + with executor.batch(): + job2 = executor.submit(func, 5) + assert job2.result() == 10 + # Check that job results are cached. + job2.submission().function = None # type: ignore + assert job2.result() == 10 + + +def test_debug_map_array(tmp_path: Path) -> None: + g = CheckFunction(5) + executor = debug.DebugExecutor(tmp_path) + jobs = executor.map_array(g, g.data1, g.data2) + print(type(jobs[0])) + print(jobs) + assert list(map(g, g.data1, g.data2)) == [j.result() for j in jobs] + + +def test_debug_submit_array(tmp_path: Path) -> None: + g = CheckFunction(5) + executor = debug.DebugExecutor(tmp_path) + fns = [functools.partial(g, x, y) for x, y in zip(g.data1, g.data2)] + jobs = executor.submit_array(fns) + assert list(map(g, g.data1, g.data2)) == [j.result() for j in jobs] + + +def test_debug_error(tmp_path: Path) -> None: + def failing_job() -> None: + raise Exception("Failed on purpose") + + executor = debug.DebugExecutor(tmp_path) + job = executor.submit(failing_job) + exception = job.exception() + assert isinstance(exception, Exception) + message = exception.args[0] + assert "Failed on purpose" in message + + +def f_42() -> int: + return 42 + + +def test_debug_triggered(tmp_path: Path) -> None: + def get_result(job: Job) -> Tuple[bool, Any]: + assert isinstance(job, debug.DebugJob) + return (job._submission._done, job._submission._result) + + executor = debug.DebugExecutor(tmp_path) + for trigger in ("wait", "done", "exception", "results"): + job = executor.submit(f_42) + assert job.state == "QUEUED" + assert get_result(job) == (False, None) + getattr(job, trigger)() + assert job.state == "DONE" + assert get_result(job) == (True, 42) + + +def test_cancel(tmp_path: Path) -> None: + executor = debug.DebugExecutor(tmp_path) + job = executor.submit(f_42) + assert job.state == "QUEUED" + job.cancel() + assert job.state == "CANCELLED" + with pytest.raises(utils.UncompletedJobError, match="was cancelled"): + job.result() + + +def test_job_environment(tmp_path: Path) -> None: + executor = debug.DebugExecutor(tmp_path) + + def use_env(): + env = JobEnvironment() + assert env.num_nodes == 1 + assert env.num_tasks == 1 + assert env.node == 0 + assert env.global_rank == 0 + assert env.local_rank == 0 + assert "DEBUG" in env.job_id + + job = executor.submit(use_env) + job.result() + # Check that we clean up the env after us. + assert "SUBMITIT_DEBUG_JOB_ID" not in os.environ diff --git a/src/submitit/local/test_local.py b/src/submitit/local/test_local.py new file mode 100644 index 0000000..8667f45 --- /dev/null +++ b/src/submitit/local/test_local.py @@ -0,0 +1,251 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import functools +import os +import re +import signal +import sys +import time +from pathlib import Path + +import pytest + +from .. import helpers +from ..core import job_environment, test_core, utils +from . import local, test_debug + + +def test_local_job(tmp_path: Path) -> None: + def func(p: int) -> int: + job_env = job_environment.JobEnvironment() + return p * job_env.local_rank + + executor = local.LocalExecutor(tmp_path) + executor.update_parameters(tasks_per_node=3, nodes=1) + job1 = executor.submit(func, 1) + + executor.update_parameters(tasks_per_node=1) + + with executor.batch(): + with pytest.raises(RuntimeError, match="with executor.batch"): + executor.update_parameters(tasks_per_node=1) + job2 = executor.submit(func, 2) + assert job1.results() == [0, 1, 2] + assert job1.task(1).result() == 1 + assert job1.task(2).result() == 2 + assert job1.task(2).result() == 2 + assert job1.exception() is None + assert job1.done() + + with pytest.raises(ValueError, match="must be between 0 and 2"): + job1.task(4).result() + + assert job2.results() == [0] + assert job2.task(0).result() == 0 + # single task job is a regular job + assert job2.task(0) is job2 + assert job2.done() + + +def test_local_map_array(tmp_path: Path) -> None: + g = test_debug.CheckFunction(5) + executor = local.LocalExecutor(tmp_path) + jobs = executor.map_array(g, g.data1, g.data2) + assert list(map(g, g.data1, g.data2)) == [j.result() for j in jobs] + + +def test_local_submit_array(tmp_path: Path) -> None: + g = test_debug.CheckFunction(5) + executor = local.LocalExecutor(tmp_path) + fns = [functools.partial(g, x, y) for x, y in zip(g.data1, g.data2)] + jobs = executor.submit_array(fns) + assert list(map(g, g.data1, g.data2)) == [j.result() for j in jobs] + + +def test_local_error(tmp_path: Path) -> None: + def failing_job() -> None: + raise Exception("Failed on purpose") + + executor = local.LocalExecutor(tmp_path) + job = executor.submit(failing_job) + exception = job.exception() + assert isinstance(exception, utils.FailedJobError) + traceback = exception.args[0] + assert "Traceback" in traceback + assert "Failed on purpose" in traceback + + +def test_pickle_output_from_main(tmp_path: Path) -> None: + class MyClass: + pass + + executor = local.LocalExecutor(tmp_path) + job = executor.submit(MyClass.__call__) + assert isinstance(job.result(), MyClass) + + +def test_get_first_task_error(tmp_path: Path) -> None: + def flaky() -> None: + job_env = job_environment.JobEnvironment() + if job_env.local_rank > 0: + raise Exception(f"Failed on purpose: {job_env.local_rank}") + + executor = local.LocalExecutor(tmp_path) + executor.update_parameters(tasks_per_node=3, nodes=1) + job = executor.submit(flaky) + exception = job.exception() + assert isinstance(exception, utils.FailedJobError) + traceback = exception.args[0] + assert "Traceback" in traceback + assert "Failed on purpose: 1" in traceback + + +def test_stdout(tmp_path: Path) -> None: + def hello() -> None: + job_env = job_environment.JobEnvironment() + print("hello from", job_env.local_rank) + print("bye from", job_env.local_rank, file=sys.stderr) + + executor = local.LocalExecutor(tmp_path) + executor.update_parameters(tasks_per_node=2, nodes=1) + job = executor.submit(hello) + + job.wait() + stdout = job.stdout() + assert stdout is not None + assert "hello from 0\n" in stdout + assert "hello from 1\n" in stdout + + stderr = job.stderr() + assert stderr is not None + assert "bye from 0\n" in stderr + assert "bye from 1\n" in stderr + + +def test_killed(tmp_path: Path) -> None: + def failing_job() -> None: + time.sleep(120) + raise Exception("Failed on purpose") + + executor = local.LocalExecutor(tmp_path) + job = executor.submit(failing_job) + assert job.state == "RUNNING" + job._process.send_signal(signal.SIGKILL) # type: ignore + time.sleep(1) + assert job.state == "INTERRUPTED" + + +@pytest.mark.skipif(not os.environ.get("SUBMITIT_SLOW_TESTS", False), reason="slow") # type: ignore +def test_long_running_job(tmp_path: Path) -> None: + def f(x: int, y: int, sleep: int = 120) -> int: + time.sleep(sleep) + return x + y + + executor = local.LocalExecutor(tmp_path) + executor.update_parameters(timeout_min=5) + job = executor.submit(f, 40, 2) + assert job.result() == 42 + + +def test_requeuing(tmp_path: Path) -> None: + func = helpers.FunctionSequence(verbose=True) + for x in range(20): + func.add(test_core.do_nothing, x=x, sleep=1) + executor = local.LocalExecutor(tmp_path, max_num_timeout=1) + executor.update_parameters(timeout_min=3 / 60, signal_delay_s=1) + job = executor.submit(func) + job.wait() + stdout = job.stdout() + assert stdout is not None + match = re.search(r"Starting from [123]/20", stdout) + assert match, f"Should have resumed from a checkpoint:\n{stdout}" + assert "timed-out too many times" in stdout, f"Unexpected stdout:\n{stdout}" + assert "(0 remaining timeouts)" in stdout, f"Unexpected stdout:\n{stdout}" + + +def test_custom_checkpoint(tmp_path: Path) -> None: + class Slacker(helpers.Checkpointable): + def __call__(self, slack: bool = True): + if slack: + print("Slacking", flush=True) + time.sleep(10) + raise Exception("I really don't want to work") + print("Working hard", flush=True) + return "worked hard" + + def __submitit_checkpoint__(self, slack: bool = True): + if slack: + print( + "Interrupted while slacking. I won't slack next time.", flush=True + ) + return utils.DelayedSubmission(self, slack=False) + + executor = local.LocalExecutor(tmp_path, max_num_timeout=1) + executor.update_parameters(timeout_min=2 / 60, signal_delay_s=1) + job = executor.submit(Slacker(True)) + job.wait() + stdout = job.stdout() + assert stdout + assert "I won't slack next time." in stdout + + +def test_make_subprocess(tmp_path: Path) -> None: + process = local.start_controller( + tmp_path, + "python -c 'import os;print(os.environ[\"SUBMITIT_LOCAL_JOB_ID\"])'", + timeout_min=1, + ) + paths = utils.JobPaths(tmp_path, str(process.pid), 0) + pg = process.pid + process.wait() + stdout = paths.stdout.read_text() + stderr = paths.stderr.read_text() + assert ( + stdout and int(stdout.strip()) == pg + ), f"PID link is broken (stderr: {stderr})" + + +def test_cancel(tmp_path: Path) -> None: + executor = local.LocalExecutor(tmp_path) + job = executor.submit(time.sleep, 10) + assert job.state == "RUNNING" + job.cancel() + time.sleep(0.1) + # Note: with local job we don't have a precise status. + assert job.state == "INTERRUPTED" + + job = executor.submit(time.sleep, 10) + process = job._process # type: ignore + job.cancel_at_deletion() + assert job.state == "RUNNING" + assert process.poll() is None + del job + time.sleep(0.1) + assert process.poll() == -2 + + +def f66(x: int, y: int = 0) -> int: # pylint: disable=unused-argument + return 66 + + +def test_load_submission(tmp_path: Path) -> None: + """Check we can load submission just from a path and job id.""" + executor = local.LocalExecutor(tmp_path) + job = executor.submit(f66, 67, y=68) + + submission = local.LocalJob(tmp_path, job.job_id).submission() + # It's important that f66 isn't a local function for the equality to work + assert submission.function is f66 + assert submission.args == (67,) + assert submission.kwargs == {"y": 68} + # Loading submission doesn't evaluate them. + assert submission._result is None + + +def test_weird_dir(weird_tmp_path: Path) -> None: + executor = local.LocalExecutor(weird_tmp_path / "%j") + executor.submit(f66, 67, 68).result() diff --git a/src/submitit/py.typed b/src/submitit/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/submitit/slurm/__init__.py b/src/submitit/slurm/__init__.py new file mode 100644 index 0000000..602d268 --- /dev/null +++ b/src/submitit/slurm/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# diff --git a/src/submitit/slurm/_sbatch_test_record.txt b/src/submitit/slurm/_sbatch_test_record.txt new file mode 100644 index 0000000..e0d01a5 --- /dev/null +++ b/src/submitit/slurm/_sbatch_test_record.txt @@ -0,0 +1,18 @@ +#!/bin/bash + +# Parameters +#SBATCH --blublu=12 +#SBATCH --error=/tmp/%j_0_log.err +#SBATCH --exclusive +#SBATCH --job-name=submitit +#SBATCH --nodes=1 +#SBATCH --open-mode=append +#SBATCH --output=/tmp/%j_0_log.out +#SBATCH --partition=learnfair +#SBATCH --signal=USR2@90 +#SBATCH --time=5 +#SBATCH --wckey=submitit + +# command +export SUBMITIT_EXECUTOR=slurm +srun --unbuffered --output /tmp/%j_%t_log.out --error /tmp/%j_%t_log.err -vv --cpu-bind none blublu bar diff --git a/src/submitit/slurm/slurm.py b/src/submitit/slurm/slurm.py new file mode 100644 index 0000000..6feff93 --- /dev/null +++ b/src/submitit/slurm/slurm.py @@ -0,0 +1,554 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import functools +import inspect +import os +import re +import shlex +import shutil +import subprocess +import sys +import typing as tp +import uuid +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from ..core import core, job_environment, logger, utils + + +def read_job_id(job_id: str) -> tp.List[Tuple[str, ...]]: + """Reads formated job id and returns a tuple with format: + (main_id, [array_index, [final_array_index]) + """ + pattern = r"(?P\d+)_\[(?P(\d+(-\d+)?(,)?)+)(\%\d+)?\]" + match = re.search(pattern, job_id) + if match is not None: + main = match.group("main_id") + array_ranges = match.group("arrays").split(",") + return [tuple([main] + array_range.split("-")) for array_range in array_ranges] + else: + main_id, *array_id = job_id.split("_", 1) + if not array_id: + return [(main_id,)] + # there is an array + array_num = str( + int(array_id[0]) + ) # trying to cast to int to make sure we understand + return [(main_id, array_num)] + + +class SlurmInfoWatcher(core.InfoWatcher): + def _make_command(self) -> Optional[List[str]]: + # asking for array id will return all status + # on the other end, asking for each and every one of them individually takes a huge amount of time + to_check = {x.split("_")[0] for x in self._registered - self._finished} + if not to_check: + return None + command = ["sacct", "-o", "JobID,State,NodeList", "--parsable2"] + for jid in to_check: + command.extend(["-j", str(jid)]) + return command + + def get_state(self, job_id: str, mode: str = "standard") -> str: + """Returns the state of the job + State of finished jobs are cached (use watcher.clear() to remove all cache) + + Parameters + ---------- + job_id: int + id of the job on the cluster + mode: str + one of "force" (forces a call), "standard" (calls regularly) or "cache" (does not call) + """ + info = self.get_info(job_id, mode=mode) + return info.get("State") or "UNKNOWN" + + def read_info(self, string: Union[bytes, str]) -> Dict[str, Dict[str, str]]: + """Reads the output of sacct and returns a dictionary containing main information""" + if not isinstance(string, str): + string = string.decode() + lines = string.splitlines() + if len(lines) < 2: + return {} # one job id does not exist (yet) + names = lines[0].split("|") + # read all lines + all_stats: Dict[str, Dict[str, str]] = {} + for line in lines[1:]: + stats = {x: y.strip() for x, y in zip(names, line.split("|"))} + job_id = stats["JobID"] + if not job_id or "." in job_id: + continue + try: + multi_split_job_id = read_job_id(job_id) + except Exception as e: + # Array id are sometimes displayed with weird chars + warnings.warn( + f"Could not interpret {job_id} correctly (please open an issue):\n{e}", + DeprecationWarning, + ) + continue + for split_job_id in multi_split_job_id: + all_stats["_".join(split_job_id[:2])] = ( + stats # this works for simple jobs, or job array unique instance + ) + # then, deal with ranges: + if len(split_job_id) >= 3: + for index in range(int(split_job_id[1]), int(split_job_id[2]) + 1): + all_stats[f"{split_job_id[0]}_{index}"] = stats + return all_stats + + +class SlurmJob(core.Job[core.R]): + _cancel_command = "scancel" + watcher = SlurmInfoWatcher(delay_s=600) + + def _interrupt(self, timeout: bool = False) -> None: + """Sends preemption or timeout signal to the job (for testing purpose) + + Parameter + --------- + timeout: bool + Whether to trigger a job time-out (if False, it triggers preemption) + """ + cmd = ["scancel", self.job_id, "--signal"] + # in case of preemption, SIGTERM is sent first + if not timeout: + subprocess.check_call(cmd + ["SIGTERM"]) + subprocess.check_call(cmd + [SlurmJobEnvironment.USR_SIG]) + + +class SlurmParseException(Exception): + pass + + +def _expand_id_suffix(suffix_parts: str) -> List[str]: + """Parse the a suffix formatted like "1-3,5,8" into + the list of numeric values 1,2,3,5,8. + """ + suffixes = [] + for suffix_part in suffix_parts.split(","): + if "-" in suffix_part: + low, high = suffix_part.split("-") + int_length = len(low) + for num in range(int(low), int(high) + 1): + suffixes.append(f"{num:0{int_length}}") + else: + suffixes.append(suffix_part) + return suffixes + + +def _parse_node_group(node_list: str, pos: int, parsed: List[str]) -> int: + """Parse a node group of the form PREFIX[1-3,5,8] and return + the position in the string at which the parsing stopped + """ + prefixes = [""] + while pos < len(node_list): + c = node_list[pos] + if c == ",": + parsed.extend(prefixes) + return pos + 1 + if c == "[": + last_pos = node_list.index("]", pos) + suffixes = _expand_id_suffix(node_list[pos + 1 : last_pos]) + prefixes = [prefix + suffix for prefix in prefixes for suffix in suffixes] + pos = last_pos + 1 + else: + for i, prefix in enumerate(prefixes): + prefixes[i] = prefix + c + pos += 1 + parsed.extend(prefixes) + return pos + + +def _parse_node_list(node_list: str): + try: + pos = 0 + parsed: List[str] = [] + while pos < len(node_list): + pos = _parse_node_group(node_list, pos, parsed) + return parsed + except ValueError as e: + raise SlurmParseException( + f"Unrecognized format for SLURM_JOB_NODELIST: '{node_list}'", e + ) from e + + +class SlurmJobEnvironment(job_environment.JobEnvironment): + _env = { + "job_id": "SLURM_JOB_ID", + "num_tasks": "SLURM_NTASKS", + "num_nodes": "SLURM_JOB_NUM_NODES", + "node": "SLURM_NODEID", + "nodes": "SLURM_JOB_NODELIST", + "global_rank": "SLURM_PROCID", + "local_rank": "SLURM_LOCALID", + "array_job_id": "SLURM_ARRAY_JOB_ID", + "array_task_id": "SLURM_ARRAY_TASK_ID", + } + + def _requeue(self, countdown: int) -> None: + jid = self.job_id + subprocess.check_call(["scontrol", "requeue", jid], timeout=60) + logger.get_logger().info(f"Requeued job {jid} ({countdown} remaining timeouts)") + + @property + def hostnames(self) -> List[str]: + """Parse the content of the "SLURM_JOB_NODELIST" environment variable, + which gives access to the list of hostnames that are part of the current job. + + In SLURM, the node list is formatted NODE_GROUP_1,NODE_GROUP_2,...,NODE_GROUP_N + where each node group is formatted as: PREFIX[1-3,5,8] to define the hosts: + [PREFIX1, PREFIX2, PREFIX3, PREFIX5, PREFIX8]. + + Link: https://hpcc.umd.edu/hpcc/help/slurmenv.html + """ + + node_list = os.environ.get(self._env["nodes"], "") + if not node_list: + return [self.hostname] + return _parse_node_list(node_list) + + +class SlurmExecutor(core.PicklingExecutor): + """Slurm job executor + This class is used to hold the parameters to run a job on slurm. + In practice, it will create a batch file in the specified directory for each job, + and pickle the task function and parameters. At completion, the job will also pickle + the output. Logs are also dumped in the same directory. + + Parameters + ---------- + folder: Path/str + folder for storing job submission/output and logs. + max_num_timeout: int + Maximum number of time the job can be requeued after timeout (if + the instance is derived from helpers.Checkpointable) + + Note + ---- + - be aware that the log/output folder will be full of logs and pickled objects very fast, + it may need cleaning. + - the folder needs to point to a directory shared through the cluster. This is typically + not the case for your tmp! If you try to use it, slurm will fail silently (since it + will not even be able to log stderr. + - use update_parameters to specify custom parameters (n_gpus etc...). If you + input erroneous parameters, an error will print all parameters available for you. + """ + + job_class = SlurmJob + + def __init__(self, folder: Union[Path, str], max_num_timeout: int = 3) -> None: + super().__init__(folder, max_num_timeout) + if not self.affinity() > 0: + raise RuntimeError( + 'Could not detect "srun", are you indeed on a slurm cluster?' + ) + + @classmethod + def _equivalence_dict(cls) -> core.EquivalenceDict: + return { + "name": "job_name", + "timeout_min": "time", + "mem_gb": "mem", + "nodes": "nodes", + "cpus_per_task": "cpus_per_task", + "gpus_per_node": "gpus_per_node", + "tasks_per_node": "ntasks_per_node", + } + + @classmethod + def _valid_parameters(cls) -> Set[str]: + """Parameters that can be set through update_parameters""" + return set(_get_default_parameters()) + + def _convert_parameters(self, params: Dict[str, Any]) -> Dict[str, Any]: + params = super()._convert_parameters(params) + # replace type in some cases + if "mem" in params: + params["mem"] = _convert_mem(params["mem"]) + return params + + def _internal_update_parameters(self, **kwargs: Any) -> None: + """Updates sbatch submission file parameters + + Parameters + ---------- + See slurm documentation for most parameters. + Most useful parameters are: time, mem, gpus_per_node, cpus_per_task, partition + Below are the parameters that differ from slurm documentation: + + signal_delay_s: int + delay between the kill signal and the actual kill of the slurm job. + setup: list + a list of command to run in sbatch befure running srun + array_parallelism: int + number of map tasks that will be executed in parallel + + Raises + ------ + ValueError + In case an erroneous keyword argument is added, a list of all eligible parameters + is printed, with their default values + + Note + ---- + Best practice (as far as Quip is concerned): cpus_per_task=2x (number of data workers + gpus_per_task) + You can use cpus_per_gpu=2 (requires using gpus_per_task and not gpus_per_node) + """ + defaults = _get_default_parameters() + in_valid_parameters = sorted(set(kwargs) - set(defaults)) + if in_valid_parameters: + string = "\n - ".join( + f"{x} (default: {repr(y)})" for x, y in sorted(defaults.items()) + ) + raise ValueError( + f"Unavailable parameter(s): {in_valid_parameters}\nValid parameters are:\n - {string}" + ) + # check that new parameters are correct + _make_sbatch_string(command="nothing to do", folder=self.folder, **kwargs) + super()._internal_update_parameters(**kwargs) + + def _internal_process_submissions( + self, delayed_submissions: tp.List[utils.DelayedSubmission] + ) -> tp.List[core.Job[tp.Any]]: + if len(delayed_submissions) == 1: + return super()._internal_process_submissions(delayed_submissions) + # array + folder = utils.JobPaths.get_first_id_independent_folder(self.folder) + folder.mkdir(parents=True, exist_ok=True) + timeout_min = self.parameters.get("time", 5) + pickle_paths = [] + for d in delayed_submissions: + pickle_path = folder / f"{uuid.uuid4().hex}.pkl" + d.set_timeout(timeout_min, self.max_num_timeout) + d.dump(pickle_path) + pickle_paths.append(pickle_path) + n = len(delayed_submissions) + # Make a copy of the executor, since we don't want other jobs to be + # scheduled as arrays. + array_ex = SlurmExecutor(self.folder, self.max_num_timeout) + array_ex.update_parameters(**self.parameters) + array_ex.parameters["map_count"] = n + self._throttle() + + first_job: core.Job[tp.Any] = array_ex._submit_command( + self._submitit_command_str + ) + tasks_ids = list(range(first_job.num_tasks)) + jobs: List[core.Job[tp.Any]] = [ + SlurmJob( + folder=self.folder, job_id=f"{first_job.job_id}_{a}", tasks=tasks_ids + ) + for a in range(n) + ] + for job, pickle_path in zip(jobs, pickle_paths): + job.paths.move_temporary_file(pickle_path, "submitted_pickle") + return jobs + + @property + def _submitit_command_str(self) -> str: + return " ".join( + [ + shlex.quote(sys.executable), + "-u -m submitit.core._submit", + shlex.quote(str(self.folder)), + ] + ) + + def _make_submission_file_text(self, command: str, uid: str) -> str: + return _make_sbatch_string( + command=command, folder=self.folder, **self.parameters + ) + + def _num_tasks(self) -> int: + nodes: int = self.parameters.get("nodes", 1) + tasks_per_node: int = max(1, self.parameters.get("ntasks_per_node", 1)) + return nodes * tasks_per_node + + def _make_submission_command(self, submission_file_path: Path) -> List[str]: + return ["sbatch", str(submission_file_path)] + + @staticmethod + def _get_job_id_from_submission_command(string: Union[bytes, str]) -> str: + """Returns the job ID from the output of sbatch string""" + if not isinstance(string, str): + string = string.decode() + output = re.search(r"job (?P[0-9]+)", string) + if output is None: + raise utils.FailedSubmissionError( + f'Could not make sense of sbatch output "{string}"\n' + "Job instance will not be able to fetch status\n" + "(you may however set the job job_id manually if needed)" + ) + return output.group("id") + + @classmethod + def affinity(cls) -> int: + return -1 if shutil.which("srun") is None else 2 + + +@functools.lru_cache() +def _get_default_parameters() -> Dict[str, Any]: + """Parameters that can be set through update_parameters""" + specs = inspect.getfullargspec(_make_sbatch_string) + zipped = zip(specs.args[-len(specs.defaults) :], specs.defaults) # type: ignore + return { + key: val for key, val in zipped if key not in {"command", "folder", "map_count"} + } + + +# pylint: disable=too-many-arguments,unused-argument, too-many-locals +def _make_sbatch_string( + command: str, + folder: tp.Union[str, Path], + job_name: str = "submitit", + partition: tp.Optional[str] = None, + time: int = 5, + nodes: int = 1, + ntasks_per_node: tp.Optional[int] = None, + cpus_per_task: tp.Optional[int] = None, + cpus_per_gpu: tp.Optional[int] = None, + num_gpus: tp.Optional[int] = None, # legacy + gpus_per_node: tp.Optional[int] = None, + gpus_per_task: tp.Optional[int] = None, + qos: tp.Optional[str] = None, # quality of service + setup: tp.Optional[tp.List[str]] = None, + mem: tp.Optional[str] = None, + mem_per_gpu: tp.Optional[str] = None, + mem_per_cpu: tp.Optional[str] = None, + signal_delay_s: int = 90, + comment: tp.Optional[str] = None, + constraint: tp.Optional[str] = None, + exclude: tp.Optional[str] = None, + account: tp.Optional[str] = None, + gres: tp.Optional[str] = None, + exclusive: tp.Optional[tp.Union[bool, str]] = None, + array_parallelism: int = 256, + wckey: str = "submitit", + stderr_to_stdout: bool = False, + map_count: tp.Optional[int] = None, # used internally + additional_parameters: tp.Optional[tp.Dict[str, tp.Any]] = None, + srun_args: tp.Optional[tp.Iterable[str]] = None, +) -> str: + """Creates the content of an sbatch file with provided parameters + + Parameters + ---------- + See slurm sbatch documentation for most parameters: + https://slurm.schedmd.com/sbatch.html + + Below are the parameters that differ from slurm documentation: + + folder: str/Path + folder where print logs and error logs will be written + signal_delay_s: int + delay between the kill signal and the actual kill of the slurm job. + setup: list + a list of command to run in sbatch before running srun + map_size: int + number of simultaneous map/array jobs allowed + additional_parameters: dict + Forces any parameter to a given value in sbatch. This can be useful + to add parameters which are not currently available in submitit. + Eg: {"mail-user": "blublu@fb.com", "mail-type": "BEGIN"} + srun_args: List[str] + Add each argument in the list to the srun call + + Raises + ------ + ValueError + In case an erroneous keyword argument is added, a list of all eligible parameters + is printed, with their default values + """ + nonslurm = [ + "nonslurm", + "folder", + "command", + "map_count", + "array_parallelism", + "additional_parameters", + "setup", + "signal_delay_s", + "stderr_to_stdout", + "srun_args", + ] + parameters = { + k: v for k, v in locals().items() if v is not None and k not in nonslurm + } + # rename and reformat parameters + parameters["signal"] = f"{SlurmJobEnvironment.USR_SIG}@{signal_delay_s}" + if num_gpus is not None: + warnings.warn( + '"num_gpus" is deprecated, please use "gpus_per_node" instead (overwritting with num_gpus)' + ) + parameters["gpus_per_node"] = parameters.pop("num_gpus", 0) + if "cpus_per_gpu" in parameters and "gpus_per_task" not in parameters: + warnings.warn( + '"cpus_per_gpu" requires to set "gpus_per_task" to work (and not "gpus_per_node")' + ) + # add necessary parameters + paths = utils.JobPaths(folder=folder) + stdout = str(paths.stdout) + stderr = str(paths.stderr) + # Job arrays will write files in the form __ + if map_count is not None: + assert isinstance(map_count, int) and map_count + parameters["array"] = f"0-{map_count - 1}%{min(map_count, array_parallelism)}" + stdout = stdout.replace("%j", "%A_%a") + stderr = stderr.replace("%j", "%A_%a") + parameters["output"] = stdout.replace("%t", "0") + if not stderr_to_stdout: + parameters["error"] = stderr.replace("%t", "0") + parameters["open-mode"] = "append" + if additional_parameters is not None: + parameters.update(additional_parameters) + # now create + lines = ["#!/bin/bash", "", "# Parameters"] + for k in sorted(parameters): + lines.append(_as_sbatch_flag(k, parameters[k])) + # environment setup: + if setup is not None: + lines += ["", "# setup"] + setup + # commandline (this will run the function and args specified in the file provided as argument) + # We pass --output and --error here, because the SBATCH command doesn't work as expected with a filename pattern + stderr_flags = [] if stderr_to_stdout else ["--error", stderr] + if srun_args is None: + srun_args = [] + + srun_cmd = _shlex_join( + ["srun", "--unbuffered", "--output", stdout, *stderr_flags, *srun_args] + ) + lines += [ + "", + "# command", + "export SUBMITIT_EXECUTOR=slurm", + # The input "command" is supposed to be a valid shell command + " ".join((srun_cmd, command)), + "", + ] + return "\n".join(lines) + + +def _convert_mem(mem_gb: float) -> str: + if mem_gb == int(mem_gb): + return f"{int(mem_gb)}GB" + return f"{int(mem_gb * 1024)}MB" + + +def _as_sbatch_flag(key: str, value: tp.Any) -> str: + key = key.replace("_", "-") + if value is True: + return f"#SBATCH --{key}" + + value = shlex.quote(str(value)) + return f"#SBATCH --{key}={value}" + + +def _shlex_join(split_command: tp.List[str]) -> str: + """Same as shlex.join, but that was only added in Python 3.8""" + return " ".join(shlex.quote(arg) for arg in split_command) diff --git a/src/submitit/slurm/test_slurm.py b/src/submitit/slurm/test_slurm.py new file mode 100644 index 0000000..7979d56 --- /dev/null +++ b/src/submitit/slurm/test_slurm.py @@ -0,0 +1,556 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +import contextlib +import os +import signal +import subprocess +import typing as tp +from pathlib import Path +from unittest.mock import patch + +import pytest + +import submitit + +from .. import helpers +from ..core import job_environment, submission, test_core, utils +from ..core.core import Job +from . import slurm + + +def _mock_log_files(job: Job[tp.Any], prints: str = "", errors: str = "") -> None: + """Write fake log files""" + filepaths = [ + str(x).replace("%j", str(job.job_id)) + for x in [job.paths.stdout, job.paths.stderr] + ] + for filepath, msg in zip(filepaths, (prints, errors)): + with Path(filepath).open("w") as f: + f.write(msg) + + +@contextlib.contextmanager +def mocked_slurm() -> tp.Iterator[test_core.MockedSubprocess]: + mock = test_core.MockedSubprocess(known_cmds=["srun"]) + try: + with mock.context(): + yield mock + finally: + # Clear the state of the shared watcher + slurm.SlurmJob.watcher.clear() + + +def test_mocked_missing_state(tmp_path: Path) -> None: + with mocked_slurm() as mock: + mock.set_job_state("12", " ") + job: slurm.SlurmJob[None] = slurm.SlurmJob(tmp_path, "12") + assert job.state == "UNKNOWN" + job._interrupt(timeout=False) # check_call is bypassed by MockedSubprocess + + +def test_job_environment() -> None: + with mocked_slurm() as mock: + mock.set_job_state("12", "RUNNING") + with mock.job_context("12"): + assert job_environment.JobEnvironment().cluster == "slurm" + + +def test_slurm_job_mocked(tmp_path: Path) -> None: + with mocked_slurm() as mock: + executor = slurm.SlurmExecutor(folder=tmp_path) + job = executor.submit(test_core.do_nothing, 1, 2, blublu=3) + # First mock job always have id 12 + assert job.job_id == "12" + assert job.state == "RUNNING" + assert job.stdout() is None + _mock_log_files(job, errors="This is the error log\n", prints="hop") + job._results_timeout_s = 0 + with pytest.raises(utils.UncompletedJobError): + job._get_outcome_and_result() + _mock_log_files(job, errors="This is the error log\n", prints="hop") + + with mock.job_context(job.job_id): + submission.process_job(job.paths.folder) + assert job.result() == 12 + # logs + assert job.stdout() == "hop" + assert job.stderr() == "This is the error log\n" + assert ( + "_USELESS_TEST_ENV_VAR_" not in os.environ + ), "Test context manager seems to be failing" + + +@pytest.mark.parametrize("use_batch_api", (False, True)) # type: ignore +def test_slurm_job_array_mocked(use_batch_api: bool, tmp_path: Path) -> None: + n = 5 + with mocked_slurm() as mock: + executor = slurm.SlurmExecutor(folder=tmp_path) + executor.update_parameters(array_parallelism=3) + data1, data2 = range(n), range(10, 10 + n) + + def add(x: int, y: int) -> int: + assert x in data1 + assert y in data2 + return x + y + + jobs: tp.List[Job[int]] = [] + if use_batch_api: + with executor.batch(): + for d1, d2 in zip(data1, data2): + jobs.append(executor.submit(add, d1, d2)) + else: + jobs = executor.map_array(add, data1, data2) + array_id = jobs[0].job_id.split("_")[0] + assert [f"{array_id}_{a}" for a in range(n)] == [j.job_id for j in jobs] + + for job in jobs: + assert job.state == "RUNNING" + with mock.job_context(job.job_id): + submission.process_job(job.paths.folder) + # trying a slurm specific method + jobs[0]._interrupt(timeout=True) # type: ignore + assert list(map(add, data1, data2)) == [j.result() for j in jobs] + # check submission file + sbatch = Job(tmp_path, job_id=array_id).paths.submission_file.read_text() + array_line = [l.strip() for l in sbatch.splitlines() if "--array" in l] + assert array_line == ["#SBATCH --array=0-4%3"] + + +def test_slurm_error_mocked(tmp_path: Path) -> None: + with mocked_slurm() as mock: + executor = slurm.SlurmExecutor(folder=tmp_path) + executor.update_parameters( + time=24, gpus_per_node=0 + ) # just to cover the function + job = executor.submit(test_core.do_nothing, 1, 2, error=12) + with mock.job_context(job.job_id): + with pytest.raises(ValueError): + submission.process_job(job.paths.folder) + _mock_log_files(job, errors="This is the error log\n") + with pytest.raises(utils.FailedJobError): + job.result() + exception = job.exception() + assert isinstance(exception, utils.FailedJobError) + + +@contextlib.contextmanager +def mock_requeue(called_with: int = None, not_called: bool = False): + assert not_called or called_with is not None + requeue = patch( + "submitit.slurm.slurm.SlurmJobEnvironment._requeue", return_value=None + ) + with requeue as _patch: + try: + yield + finally: + if not_called: + _patch.assert_not_called() + else: + _patch.assert_called_with(called_with) + + +def get_signal_handler(job: Job) -> job_environment.SignalHandler: + env = slurm.SlurmJobEnvironment() + delayed = utils.DelayedSubmission.load(job.paths.submitted_pickle) + sig = job_environment.SignalHandler(env, job.paths, delayed) + return sig + + +def test_requeuing_checkpointable(tmp_path: Path, fast_forward_clock) -> None: + usr_sig = submitit.JobEnvironment._usr_sig() + fs0 = helpers.FunctionSequence() + fs0.add(test_core._three_time, 10) + assert isinstance(fs0, helpers.Checkpointable) + + # Start job with a 60 minutes timeout + with mocked_slurm(): + executor = slurm.SlurmExecutor(folder=tmp_path, max_num_timeout=1) + executor.update_parameters(time=60) + job = executor.submit(fs0) + + sig = get_signal_handler(job) + + fast_forward_clock(minutes=30) + # Preempt the job after 30 minutes, the job hasn't timeout. + with pytest.raises(SystemExit), mock_requeue(called_with=1): + sig.checkpoint_and_try_requeue(usr_sig) + + # Restart the job. + sig = get_signal_handler(job) + fast_forward_clock(minutes=50) + + # This time the job as timed out, + # but we have max_num_timeout=1, so we should requeue. + # We are a little bit under the requested timedout, but close enough + # to not consider this a preemption + with pytest.raises(SystemExit), mock_requeue(called_with=0): + sig.checkpoint_and_try_requeue(usr_sig) + + # Restart the job. + sig = get_signal_handler(job) + fast_forward_clock(minutes=55) + + # The job has already timed out twice, we should stop here. + usr_sig = slurm.SlurmJobEnvironment._usr_sig() + with mock_requeue(not_called=True), pytest.raises( + utils.UncompletedJobError, match="timed-out too many times." + ): + sig.checkpoint_and_try_requeue(usr_sig) + + +def test_requeuing_not_checkpointable(tmp_path: Path, fast_forward_clock) -> None: + usr_sig = submitit.JobEnvironment._usr_sig() + # Start job with a 60 minutes timeout + with mocked_slurm(): + executor = slurm.SlurmExecutor(folder=tmp_path, max_num_timeout=1) + executor.update_parameters(time=60) + job = executor.submit(test_core._three_time, 10) + + # simulate job start + sig = get_signal_handler(job) + fast_forward_clock(minutes=30) + + with mock_requeue(not_called=True): + sig.bypass(signal.Signals.SIGTERM) + + # Preempt the job after 30 minutes, the job hasn't timeout. + with pytest.raises(SystemExit), mock_requeue(called_with=1): + sig.checkpoint_and_try_requeue(usr_sig) + + # Restart the job from scratch + sig = get_signal_handler(job) + fast_forward_clock(minutes=50) + + # Wait 50 minutes, now the job as timed out. + with mock_requeue(not_called=True), pytest.raises( + utils.UncompletedJobError, match="timed-out and not checkpointable" + ): + sig.checkpoint_and_try_requeue(usr_sig) + + +def test_checkpoint_and_exit(tmp_path: Path) -> None: + usr_sig = submitit.JobEnvironment._usr_sig() + with mocked_slurm(): + executor = slurm.SlurmExecutor(folder=tmp_path, max_num_timeout=1) + executor.update_parameters(time=60) + job = executor.submit(test_core._three_time, 10) + + sig = get_signal_handler(job) + with pytest.raises(SystemExit), mock_requeue(not_called=True): + sig.checkpoint_and_exit(usr_sig) + + # checkpoint_and_exit doesn't modify timeout counters. + delayed = utils.DelayedSubmission.load(job.paths.submitted_pickle) + assert delayed._timeout_countdown == 1 + + +def test_make_sbatch_string() -> None: + string = slurm._make_sbatch_string( + command="blublu bar", + folder="/tmp", + partition="learnfair", + exclusive=True, + additional_parameters=dict(blublu=12), + srun_args=["-vv", "--cpu-bind", "none"], + ) + assert "partition" in string + assert "--command" not in string + assert "constraint" not in string + record_file = Path(__file__).parent / "_sbatch_test_record.txt" + if not record_file.exists(): + record_file.write_text(string) + recorded = record_file.read_text() + changes = [] + for k, (line1, line2) in enumerate(zip(string.splitlines(), recorded.splitlines())): + if line1 != line2: + changes.append(f'line #{k + 1}: "{line2}" -> "{line1}"') + if changes: + print(string) + print("# # # # #") + print(recorded) + message = ["Difference with reference file:"] + changes + message += [ + "", + "Delete the record file if this is normal:", + f"rm {record_file}", + ] + raise AssertionError("\n".join(message)) + + +def test_make_sbatch_string_gpu() -> None: + string = slurm._make_sbatch_string(command="blublu", folder="/tmp", gpus_per_node=2) + assert "--gpus-per-node=2" in string + + +def test_make_sbatch_stderr() -> None: + string = slurm._make_sbatch_string( + command="blublu", folder="/tmp", stderr_to_stdout=True + ) + assert "--error" not in string + + +def test_update_parameters(tmp_path: Path) -> None: + with mocked_slurm(): + executor = submitit.AutoExecutor(folder=tmp_path) + executor.update_parameters(mem_gb=3.5) + assert executor._executor.parameters["mem"] == "3584MB" + + +def test_update_parameters_error(tmp_path: Path) -> None: + with mocked_slurm(): + executor = slurm.SlurmExecutor(folder=tmp_path) + with pytest.raises(ValueError): + executor.update_parameters(blublu=12) + + +def test_read_info() -> None: + example = """JobID|State +5610980|RUNNING +5610980.ext+|RUNNING +5610980.0|RUNING +20956421_0|RUNNING +20956421_[2-4%25]|PENDING +""" + output = slurm.SlurmInfoWatcher().read_info(example) + assert output["5610980"] == {"JobID": "5610980", "State": "RUNNING"} + assert output["20956421_2"] == {"JobID": "20956421_[2-4%25]", "State": "PENDING"} + assert set(output) == { + "5610980", + "20956421_0", + "20956421_2", + "20956421_3", + "20956421_4", + } + + +@pytest.mark.parametrize( # type: ignore + "name,state", + [("12_0", "R"), ("12_1", "U"), ("12_2", "X"), ("12_3", "U"), ("12_4", "X")], +) +def test_read_info_array(name: str, state: str) -> None: + example = "JobID|State\n12_0|R\n12_[2,4-12]|X" + watcher = slurm.SlurmInfoWatcher() + for jobid in ["12_2", "12_4"]: + watcher.register_job(jobid) + output = watcher.read_info(example) + assert output.get(name, {}).get("State", "U") == state + + +@pytest.mark.parametrize( # type: ignore + "job_id,expected", + [ + ("12", [(12,)]), + ("12_0", [(12, 0)]), + ("20_[2-7%56]", [(20, 2, 7)]), + ("20_[2-7,12-17,22%56]", [(20, 2, 7), (20, 12, 17), (20, 22)]), + ("20_[0%1]", [(20, 0)]), + ], +) +def test_read_job_id( + job_id: str, expected: tp.List[tp.Tuple[tp.Union[int, str], ...]] +) -> None: + output = slurm.read_job_id(job_id) + assert output == [tuple(str(x) for x in group) for group in expected] + + +@pytest.mark.parametrize( # type: ignore + "string,expected", + [ + (b"Submitted batch job 5610208\n", "5610208"), + ("Submitted batch job 5610208\n", "5610208"), + ], +) +def test_get_id_from_submission_command(string: str, expected: str) -> None: + output = slurm.SlurmExecutor._get_job_id_from_submission_command(string) + assert output == expected + + +def test_get_id_from_submission_command_raise() -> None: + with pytest.raises(utils.FailedSubmissionError): + slurm.SlurmExecutor._get_job_id_from_submission_command(string=b"blublu") + + +def test_watcher() -> None: + with mocked_slurm() as mock: + watcher = slurm.SlurmInfoWatcher() + mock.set_job_state("12", "RUNNING") + assert watcher.num_calls == 0 + state = watcher.get_state(job_id="11") + assert set(watcher._info_dict.keys()) == {"12"} + assert watcher._registered == {"11"} + + assert state == "UNKNOWN" + mock.set_job_state("12", "FAILED") + state = watcher.get_state(job_id="12", mode="force") + assert state == "FAILED" + # TODO: this test is implementation specific. Not sure if we can rewrite it another way. + assert watcher._registered == {"11", "12"} + assert watcher._finished == {"12"} + + +def test_get_default_parameters() -> None: + defaults = slurm._get_default_parameters() + assert defaults["nodes"] == 1 + + +def test_name() -> None: + assert slurm.SlurmExecutor.name() == "slurm" + + +@contextlib.contextmanager +def with_slurm_job_nodelist(node_list: str) -> tp.Iterator[slurm.SlurmJobEnvironment]: + os.environ["SLURM_JOB_ID"] = "1" + os.environ["SLURM_JOB_NODELIST"] = node_list + yield slurm.SlurmJobEnvironment() + del os.environ["SLURM_JOB_NODELIST"] + del os.environ["SLURM_JOB_ID"] + + +def test_slurm_node_list() -> None: + with with_slurm_job_nodelist("compute-b24") as env: + assert ["compute-b24"] == env.hostnames + with with_slurm_job_nodelist("compute-a1,compute-b2") as env: + assert ["compute-a1", "compute-b2"] == env.hostnames + with with_slurm_job_nodelist("compute-b2[1,2]") as env: + assert ["compute-b21", "compute-b22"] == env.hostnames + with with_slurm_job_nodelist("compute-b2[011,022]") as env: + assert ["compute-b2011", "compute-b2022"] == env.hostnames + with with_slurm_job_nodelist("compute-b2[1-3]") as env: + assert ["compute-b21", "compute-b22", "compute-b23"] == env.hostnames + with with_slurm_job_nodelist("compute-b2[1-3,5,6,8]") as env: + assert [ + "compute-b21", + "compute-b22", + "compute-b23", + "compute-b25", + "compute-b26", + "compute-b28", + ] == env.hostnames + with with_slurm_job_nodelist("compute-b2[1-3,5-6,8]") as env: + assert [ + "compute-b21", + "compute-b22", + "compute-b23", + "compute-b25", + "compute-b26", + "compute-b28", + ] == env.hostnames + with with_slurm_job_nodelist("compute-b2[1-3,5-6,8],compute-a1") as env: + assert [ + "compute-b21", + "compute-b22", + "compute-b23", + "compute-b25", + "compute-b26", + "compute-b28", + "compute-a1", + ] == env.hostnames + with with_slurm_job_nodelist("compute[042,044]") as env: + assert ["compute042", "compute044"] == env.hostnames + with with_slurm_job_nodelist("compute[042-043,045,048-049]") as env: + assert [ + "compute042", + "compute043", + "compute045", + "compute048", + "compute049", + ] == env.hostnames + + +def test_slurm_node_list_online_documentation() -> None: + with with_slurm_job_nodelist("compute-b24-[1-3,5-9],compute-b25-[1,4,8]") as env: + assert [ + "compute-b24-1", + "compute-b24-2", + "compute-b24-3", + "compute-b24-5", + "compute-b24-6", + "compute-b24-7", + "compute-b24-8", + "compute-b24-9", + "compute-b25-1", + "compute-b25-4", + "compute-b25-8", + ] == env.hostnames + + +def test_slurm_invalid_parse() -> None: + with pytest.raises(slurm.SlurmParseException): + with with_slurm_job_nodelist("compute-b2[1-,4]") as env: + print(env.hostnames) + with pytest.raises(slurm.SlurmParseException): + with with_slurm_job_nodelist("compute-b2[1,2,compute-b3]") as env: + print(env.hostnames) + + +def test_slurm_missing_node_list() -> None: + with with_slurm_job_nodelist("") as env: + assert [env.hostname] == env.hostnames + + +def test_slurm_weird_dir(weird_tmp_path: Path) -> None: + if "\n" in weird_tmp_path.name: + pytest.skip("test doesn't support newline in 'weird_tmp_path'") + with mocked_slurm(): + executor = slurm.SlurmExecutor(folder=weird_tmp_path) + job = executor.submit(test_core.do_nothing, 1, 2, blublu=3) + + # Touch the ouputfiles + job.paths.stdout.write_text("") + job.paths.stderr.write_text("") + + # Try to read sbatch flags from the file like sbatch would do it. + sbatch_args = {} + for l in job.paths.submission_file.read_text().splitlines(): + if not l.startswith("#SBATCH"): + continue + if "=" not in l: + continue + key, val = l[len("#SBATCH") :].strip().split("=", 1) + sbatch_args[key] = val.replace("%j", job.job_id).replace("%t", "0") + + # We do not quote --output and --error values here, + # because we want to check if they have been properly quoted before. + subprocess.check_call("ls " + sbatch_args["--output"], shell=True) + subprocess.check_call("ls " + sbatch_args["--error"], shell=True) + + +@pytest.mark.parametrize("params", [{}, {"mem_gb": None}]) # type: ignore +def test_slurm_through_auto(params: tp.Dict[str, int], tmp_path: Path) -> None: + with mocked_slurm(): + executor = submitit.AutoExecutor(folder=tmp_path) + executor.update_parameters( + **params, slurm_additional_parameters={"mem_per_gpu": 12} + ) + job = executor.submit(test_core.do_nothing, 1, 2, blublu=3) + text = job.paths.submission_file.read_text() + mem_lines = [x for x in text.splitlines() if "#SBATCH --mem" in x] + assert len(mem_lines) == 1, f"Unexpected lines: {mem_lines}" + + +def test_slurm_job_no_stderr(tmp_path: Path) -> None: + def fail_silently(): + raise ValueError("Too bad") + + with mocked_slurm() as mock: + executor = slurm.SlurmExecutor(folder=tmp_path) + # Failed but no stderr + job = executor.submit(fail_silently) + _mock_log_files(job, prints="job is running ...\n") + job._results_timeout_s = 0 + with pytest.raises(utils.UncompletedJobError, match="job is running ..."): + job._get_outcome_and_result() + + # Failed but no stderr nor stdout + mock.set_job_state("13", "RUNNING") + job = executor.submit(fail_silently) + job._results_timeout_s = 0 + # Explicitly unlink stdout because submitit is writing there on startup + # job.paths.stdout.unlink() + with pytest.raises( + utils.UncompletedJobError, match="No output/error stream produced !" + ): + job._get_outcome_and_result() diff --git a/src/submitit/test_documentation.py b/src/submitit/test_documentation.py new file mode 100644 index 0000000..df0bf23 --- /dev/null +++ b/src/submitit/test_documentation.py @@ -0,0 +1,102 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import re +import typing as tp +from pathlib import Path + +import submitit + + +class MarkdownLink: + """Handle to a markdown link, for easy existence test and printing + (external links are not tested) + """ + + regex = re.compile(r"\[(?P.+?)\]\((?P\S+?)\)") + + def __init__(self, root: Path, file: Path, name: str, link: str) -> None: + self.root = root + self.file = file + self.name = name + self.link = link + + def exists(self) -> bool: + if self.link.startswith("http"): + # We don't check external urls. + return True + link = self.link.split("#")[0] + if not link: + return False + fullpath = self.root / self.file.parent / link + return fullpath.exists() + + def __repr__(self) -> str: + return f"[{self.link}]({self.name}) in file {self.file}" + + +def _get_root() -> Path: + root = Path(__file__).parent.parent.absolute() + assert (root / "pyproject.toml").exists(), f"Wrong root folder: {root}" + return root + + +def _get_markdown_files(root: Path) -> tp.List[Path]: + return [ + md + for pattern in ("*.md", "submitit/**/*.md", "docs/**/*.md") + for md in root.glob(pattern) + ] + + +def _get_all_markdown_links(root: Path, files: tp.List[Path]) -> tp.List[MarkdownLink]: + """Returns a list of all existing markdown links""" + pattern = MarkdownLink.regex + links = [] + for file in files: + for match in pattern.finditer(file.read_text()): + links.append( + MarkdownLink(root, file, match.group("name"), match.group("link")) + ) + return links + + +def test_assert_markdown_links_not_broken() -> None: + root = _get_root() + files = _get_markdown_files(root) + assert len(files) > 3 + + links = _get_all_markdown_links(root, files) + assert len(links) > 5, "There should be several hyperlinks!" + broken_links = [l for l in links if not l.exists()] + assert not broken_links + + +def _replace_relative_links(regex: tp.Match[str]) -> str: + """Converts relative links into links to master + so that links on Pypi long description are correct + """ + string: str = regex.group() + link = regex.group("link") + name = regex.group("name") + version = submitit.__version__ + if not link.startswith("http") and Path(link).exists(): + github_url = f"github.com/facebookincubator/submitit/blob/{version}" + string = f"[{name}](https://{github_url}/{link})" + return string + + +def expand_links(): + readme = _get_root() / "README.md" + assert readme.exists() + + desc = readme.read_text(encoding="utf-8") + desc = re.sub(MarkdownLink.regex, _replace_relative_links, desc) + readme.write_text(desc) + + +if __name__ == "__main__": + expand_links() diff --git a/src/submitit/test_helpers.py b/src/submitit/test_helpers.py new file mode 100644 index 0000000..26294da --- /dev/null +++ b/src/submitit/test_helpers.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import time +import typing as tp +from pathlib import Path + +import pytest + +from . import helpers +from .core import core, utils + + +def _three_time(x: int) -> int: + return 3 * x + + +requires_rsync = pytest.mark.skipif( + not helpers.RsyncSnapshot.available(), reason="Rsync is required for snapshotting" +) + + +def test_function_sequence_checkpoint(tmp_path: Path) -> None: + file = tmp_path / "test_funcseq.pkl" + fs0 = helpers.FunctionSequence(verbose=True) + fs0.add(_three_time, 4) + fs0.add(_three_time, 5) + assert len(fs0) == 2 + assert sum(x.done() for x in fs0) == 0 + utils.cloudpickle_dump(fs0, file) + fs1 = utils.pickle_load(file) + assert sum(x.done() for x in fs1) == 0 + assert fs1() == [12, 15] + assert sum(x.done() for x in fs1) == 2 + + +def test_as_completed(executor) -> None: + def f(x: float) -> float: + time.sleep(x) + return x + + # slow need to be > 1.5s otherwise it might finish before we start polling. + slow, fast = 1.5, 0.1 + # One slow job and two fast jobs. + jobs = executor.map_array(f, [slow, fast, fast]) + start = time.time() + finished_jobs = [] + for n, j in enumerate(helpers.as_completed(jobs, poll_frequency=0.1)): + elapsed = time.time() - start + if n < 2: + # we start getting result before the slow job finished. + assert elapsed < slow + finished_jobs.append(j) + # We get fast job results first, then result of the slow job. + assert [fast, fast, slow] == [j.result() for j in finished_jobs] + assert jobs[0] is finished_jobs[-1] + + +@requires_rsync +def test_snapshot(tmp_path: Path) -> None: + cwd = Path.cwd() + with helpers.RsyncSnapshot(tmp_path): + assert Path.cwd() == tmp_path + assert (tmp_path / "submitit/test_helpers.py").exists() + assert Path.cwd() == cwd + + +@requires_rsync +def test_snapshot_excludes(tmp_path: Path) -> None: + exclude = ["submitit/test_*"] + with helpers.RsyncSnapshot(snapshot_dir=tmp_path, exclude=exclude): + assert (tmp_path / "submitit/helpers.py").exists() + assert not (tmp_path / "submitit/test_helpers.py").exists() + + +@requires_rsync +def test_job_use_snapshot_cwd(executor, tmp_path: Path) -> None: + with helpers.RsyncSnapshot(snapshot_dir=tmp_path): + job = executor.submit(os.getcwd) + assert Path(job.result()) == tmp_path + + +@requires_rsync +def test_job_use_snapshot_modules(executor, tmp_path: Path) -> None: + with helpers.RsyncSnapshot(snapshot_dir=tmp_path): + + def submitit_file() -> Path: + # pylint: disable=import-outside-toplevel + import submitit + + return Path(submitit.__file__) + + job = executor.submit(submitit_file) + # Here we load the normal submitit + assert submitit_file() == Path(__file__).parent / "__init__.py" + # In the job we should import submitit from the snapshot dir + assert job.result() == tmp_path / "submitit/__init__.py" + + +class FakeInfoWatcherWithTimer(core.InfoWatcher): + # pylint: disable=abstract-method + def __init__(self, delay_s: int = 60, time_change: float = 0.02): + super().__init__(delay_s) + self.start_timer = time.time() + self.time_change = time_change + + def get_state(self, job_id: str, mode: str = "standard") -> str: + duration = time.time() - self.start_timer + if duration < self.time_change: + return "pending" + elif 2 * self.time_change > duration > self.time_change: + return "running" + if job_id == "failed": + return "failed" + return "done" + + +class FakeJobWithTimer(core.Job[core.R]): + watcher = FakeInfoWatcherWithTimer() + + +def test_monitor_jobs(tmp_path: Path) -> None: + job: FakeJobWithTimer[int] = FakeJobWithTimer(job_id="failed", folder=tmp_path) + job2: FakeJobWithTimer[int] = FakeJobWithTimer(job_id="succeeded", folder=tmp_path) + jobs = [job, job2] + helpers.monitor_jobs(jobs, 0.02, test_mode=True) + assert all(j for j in jobs if j.done()) + assert set(j for j in jobs if j.state.upper() == "FAILED") == {job} + + +def _get_env() -> tp.Dict[str, str]: + return { + x: y for x, y in os.environ.items() if x.startswith(("SLURM_", "SUBMITIT_")) + } + + +def test_clean_env() -> None: + base = _get_env() + with utils.environment_variables(SLURM_BLUBLU=12, SUBMITIT_BLUBLU=12): + assert len(_get_env()) == len(base) + 2 + with helpers.clean_env(): + assert not _get_env() + assert len(_get_env()) == len(base) + 2 + assert _get_env() == base diff --git a/src/submitit/test_pickle.py b/src/submitit/test_pickle.py new file mode 100644 index 0000000..0abed7c --- /dev/null +++ b/src/submitit/test_pickle.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import pickle +from weakref import ref + +import pytest + +from .local.debug import DebugExecutor +from .local.local import LocalExecutor + + +def job_with_weakref(ex): + class MyObject: + hello = "world" + + a = MyObject() + a_ref = ref(a) + assert a_ref() is a + + def f(a_ref): + a = a_ref() + assert a is not None + return a_ref().hello + + return ex.submit(f, ref(a)) + + +@pytest.mark.xfail(reason="'a' is GC-ed before we call the function") +def test_weakref_no_pickle(tmp_path): + ex = DebugExecutor(tmp_path) + assert job_with_weakref(ex).result() == "world" + + +@pytest.mark.xfail(reason="'ref(a)' can't be pickled") +def test_weakref_with_pickle(tmp_path): + ex = LocalExecutor(tmp_path) + assert job_with_weakref(ex).result() == "world" + + +def hello_fn() -> None: + print("hello world") + + +def test_nested_pickling(tmp_path): + def make_pickle() -> bytes: + return pickle.dumps(hello_fn) + + pkl = make_pickle() + assert bytes(__name__, "ascii") in pkl + assert b"hello_fn" in pkl + ex = LocalExecutor(tmp_path) + j = ex.submit(make_pickle) + assert j.result() == pkl + + +@pytest.mark.xfail(reason="Submitit changes __main__") +def test_submitit_respects_main(tmp_path): + # TODO: I think this is the root cause of issue #11 + # https://github.com/facebookincubator/submitit/issues/11 + # Some programs like pytorch-lightning are dependent on the value of __main__ + # See how `pdb` manage to restore the correct __main__: + # https://sourcegraph.com/github.com/python/cpython/-/blob/Lib/pdb.py#L1549 + # But maybe we could fix #11 by just using + # `from submitit.core.submission import submitit_main` + # as in https://github.com/facebookincubator/submitit/issues/11#issuecomment-713148952 + + def get_main() -> str: + # pylint: disable=import-outside-toplevel + import __main__ # type: ignore + + return getattr(__main__, "__file__", "") + + main = get_main() + ex = LocalExecutor(tmp_path) + j_main = ex.submit(get_main).result() + assert main == j_main