Add Hugging Face links #43
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CI workflow: builds the conda environment, then smoke-tests METL pretraining,
# finetuning, checkpoint conversion, and inference on every push and pull request.
name: Tests

on:
  - push
  - pull_request

# Every run step uses a login shell so the conda environment activated by
# setup-miniconda is available. The -e flag makes a multi-line script stop at
# the first failing command instead of reporting only the last command's status.
defaults:
  run:
    shell: bash -el {0}

jobs:
  # Installs the conda environment and trains METL
  train:
    name: Test METL training
    runs-on: ${{ matrix.os }}
    strategy:
      # Let the remaining OS jobs finish even if one of them fails
      fail-fast: false
      matrix:
        # Could also test on the beta M1 macOS runner
        # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
        # Disabled testing on Windows because there were too many Windows-specific PyTorch training issues
        # Re-enable and debug these if there is user demand for Windows
        os:
          - macos-latest
          - ubuntu-latest
          # - windows-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      # Can set up package caching later conda-incubator/setup-miniconda
      - name: Install conda environment
        uses: conda-incubator/setup-miniconda@v3
        with:
          activate-environment: metl
          environment-file: environment.yml
          auto-activate-base: false
          # No miniforge-variant: the default Miniforge3 installer bundles mamba,
          # and the separate Mambaforge installer has been retired upstream.
          miniforge-version: 'latest'
          use-mamba: true

      # Installs latest commit from main branch
      - name: Install metl package from metl-pretrained repo
        run: pip install git+https://github.com/gitter-lab/metl-pretrained

      # Log conda environment contents
      - name: Log conda environment
        run: conda list

      # Pretrain source model on GFP Rosetta dataset
      # (folded scalar: the lines below are joined into one command)
      - name: Pretrain source METL model
        run: >-
          python code/train_source_model.py @args/pretrain_avgfp_local.txt
          --max_epochs 5
          --limit_train_batches 5
          --limit_val_batches 5
          --limit_test_batches 5

      # Finetune target model on GFP DMS dataset
      # (folded scalar: the lines below are joined into one command)
      - name: Finetune target METL model
        run: >-
          python code/train_target_model.py @args/finetune_avgfp_local.txt
          --enable_progress_bar false
          --enable_simple_progress_messages
          --max_epochs 50
          --unfreeze_backbone_at_epoch 25

      # Load target model checkpoint and run inference on example variants
      - name: Load and test target METL model
        # Log directory name is different on every run, so the paths use a glob.
        # The glob is expanded where the variable is used unquoted, not at assignment.
        run: |
          ckpt_glob=output/training_logs/*/checkpoints/epoch=49-step=50.ckpt
          python code/convert_ckpt.py --ckpt_path $ckpt_glob
          pt_glob=output/training_logs/*/checkpoints/*.pt
          python code/tests.py --ckpt_path $pt_glob --variants E3K,G102S_T36P,S203T,K207R_V10A,D19G,F25S,E113V --dataset avgfp