Skip to content

Commit

Permalink
Pre release changes for production (#59)
Browse files Browse the repository at this point in the history
* clean requirements

* rm taming deps

* isort, black

* mv lipips, license

* clean vq, fix path

* fix loss path, gitignore

* tested requirements pt13

* fix numpy req for python3.8, add tests

* fix name

* fix dep scipy 3.8 pt2

* add black test formatter
  • Loading branch information
benjaminaubin authored Jul 26, 2023
1 parent 4a3f0f5 commit e596332
Show file tree
Hide file tree
Showing 31 changed files with 641 additions and 127 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/black.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: Run black
on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Install venv
run: |

This comment has been minimized.

Copy link
@Alexanderz34

Alexanderz34 Jul 29, 2023

run

sudo apt-get -y install python3.10-venv
- uses: psf/black@stable
with:
options: "--check --verbose -l88"
src: "./sgm ./scripts ./main.py"
26 changes: 26 additions & 0 deletions .github/workflows/test-build.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: Build package

on:
push:
pull_request:

jobs:
build:
name: Build
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.10"]
requirements-file: ["pt2", "pt13"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements/${{ matrix.requirements-file }}.txt
pip install .
9 changes: 7 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# extensions
*.egg-info
*.py[cod]

# envs
.pt13
.pt2
.pt2_2

# directories
/checkpoints
/dist
/outputs
build
/build
/src
23 changes: 17 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,9 @@ This is assuming you have navigated to the `generative-models` root after clonin

```shell
# install required packages from pypi
python3 -m venv .pt1
source .pt1/bin/activate
pip3 install wheel
pip3 install -r requirements_pt13.txt
python3 -m venv .pt13
source .pt13/bin/activate
pip3 install -r requirements/pt13.txt
```

**PyTorch 2.0**
Expand All @@ -72,8 +71,20 @@ pip3 install -r requirements_pt13.txt
# install required packages from pypi
python3 -m venv .pt2
source .pt2/bin/activate
pip3 install wheel
pip3 install -r requirements_pt2.txt
pip3 install -r requirements/pt2.txt
```


#### 3. Install `sgm`

```shell
pip3 install .
```

#### 4. Install `sdata` for training

```shell
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```

## Packaging
Expand Down
11 changes: 4 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,18 @@
import torch
import torchvision
import wandb
from PIL import Image
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only

from sgm.util import (
exists,
instantiate_from_config,
isheatmap,
)
from sgm.util import exists, instantiate_from_config, isheatmap

MULTINODE_HACKS = True

Expand Down Expand Up @@ -910,11 +906,12 @@ def divein(*args, **kwargs):
trainer.test(model, data)
except RuntimeError as err:
if MULTINODE_HACKS:
import requests
import datetime
import os
import socket

import requests

device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
hostname = socket.gethostname()
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
Expand Down
40 changes: 40 additions & 0 deletions requirements/pt13.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
black==23.7.0
chardet>=5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
numpy>=1.24.4
omegaconf>=2.3.0
onnx<=1.12.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==1.8.5
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=1.25.0
tensorboardx==2.5.1
timm>=0.9.2
tokenizers==0.12.1
--extra-index-url https://download.pytorch.org/whl/cu117
torch==1.13.1+cu117
torchaudio==0.13.1
torchdata==0.5.1
torchmetrics>=1.0.1
torchvision==0.14.1+cu117
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0.post1
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers==0.0.16
39 changes: 39 additions & 0 deletions requirements/pt2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
black==23.7.0
chardet==5.1.0
clip @ git+https://github.com/openai/CLIP.git
einops>=0.6.1
fairscale>=0.4.13
fire>=0.5.0
fsspec>=2023.6.0
invisible-watermark>=0.2.0
kornia==0.6.9
matplotlib>=3.7.2
natsort>=8.4.0
ninja>=1.11.1
numpy>=1.24.4
omegaconf>=2.3.0
open-clip-torch>=2.20.0
opencv-python==4.6.0.66
pandas>=2.0.3
pillow>=9.5.0
pudb>=2022.1.3
pytorch-lightning==2.0.1
pyyaml>=6.0.1
scipy>=1.10.1
streamlit>=0.73.1
tensorboardx==2.6
timm>=0.9.2
tokenizers==0.12.1
torch>=2.0.1
torchaudio>=2.0.2
torchdata==0.6.1
torchmetrics>=1.0.1
torchvision>=0.15.2
tqdm>=4.65.0
transformers==4.19.1
triton==2.0.0
urllib3<1.27,>=1.25.4
wandb>=0.15.6
webdataset>=0.2.33
wheel>=0.41.0
xformers>=0.0.20
41 changes: 0 additions & 41 deletions requirements_pt13.txt

This file was deleted.

41 changes: 0 additions & 41 deletions requirements_pt2.txt

This file was deleted.

1 change: 1 addition & 0 deletions scripts/demo/sampling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering

Expand Down
19 changes: 9 additions & 10 deletions scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import math
import os
from typing import Union, List
from typing import List, Union

import math
import numpy as np
import streamlit as st
import torch
from PIL import Image
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import OmegaConf, ListConfig
from omegaconf import ListConfig, OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid
from safetensors.torch import load_file as load_safetensors

from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
EulerAncestralSampler,
DPMPP2SAncestralSampler,
DPMPP2MSampler,
LinearMultistepSampler,
)
from sgm.util import append_dims
from sgm.util import instantiate_from_config
from sgm.util import append_dims, instantiate_from_config


class WatermarkEmbedder:
Expand Down
5 changes: 3 additions & 2 deletions scripts/util/detection/nsfw_and_watermark_dectection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import torch

import clip
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import clip

RESOURCES_ROOT = "scripts/util/detection/"

Expand Down
5 changes: 2 additions & 3 deletions sgm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .data import StableDataModuleFromConfig
from .models import AutoencodingEngine, DiffusionEngine
from .util import instantiate_from_config, get_configs_path
from .util import get_configs_path, instantiate_from_config

__version__ = "0.0.1"
__version__ = "0.1.0"
4 changes: 2 additions & 2 deletions sgm/data/cifar10.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CIFAR10DataDictWrapper(Dataset):
Expand Down
4 changes: 2 additions & 2 deletions sgm/data/mnist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torchvision
import pytorch_lightning as pl
from torchvision import transforms
import torchvision

This comment has been minimized.

Copy link
@Alexanderz34

Alexanderz34 Jul 29, 2023

or

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class MNISTDataDictWrapper(Dataset):
Expand Down
Loading

0 comments on commit e596332

Please sign in to comment.