minor
hearmeneigh committed Sep 4, 2023
1 parent c0d7a00 commit 22a65c9
Showing 3 changed files with 59 additions and 7 deletions.
README.md: 54 changes (48 additions, 6 deletions)
@@ -1,20 +1,25 @@
# Dataset Rising

-> Toolchain for training Stable Diffusion 1.x, Stable Diffusion 2.x, and Stable Diffusion XL models
-> with custom image datasets.
+> A toolchain for creating and training Stable Diffusion 1.x, Stable Diffusion 2.x, and Stable Diffusion XL models
+> with custom datasets.
With this toolchain, you can:
* Crawl and download metadata and images from 'booru' style image boards
* Combine multiple sources of images (including your own custom sources)
-* Build datasets based on your personal preferences
-* Train Stable Diffusion models on your datasets
+* Build datasets based on your personal preferences and filters
+* Train Stable Diffusion models with your datasets
* Convert models into [Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/master) compatible models
* Use only the parts you need – the toolchain uses a modular design, YAML configuration files, and JSONL data exchange formats
-* Work with confidence that the tooling has been tested with Nvidia's RTX30x0, RTX40x0, A100, and H100 GPUs
+* Work with confidence that the end-to-end tooling has been tested with Nvidia's RTX30x0, RTX40x0, A100, and H100 GPUs

## Requirements
* Python `>= 3.9.6`
* Docker `>= 24.0.0`
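
A quick way to check both prerequisites from a shell:

```bash
python3 --version   # expect 3.9.6 or newer
docker --version    # expect 24.0.0 or newer
```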

## Tested With
* MacOS 13 (M1)
* Ubuntu 22 (x86_64)


## Setting Up
This step creates a virtual environment, installs the required Python packages, and sets up a MongoDB database on Docker.
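
As a rough sketch of what this step amounts to (the requirements file and container name below are generic placeholders, not necessarily the project's actual entry points):

```bash
# create and activate a virtual environment
python3 -m venv venv
source venv/bin/activate

# install the Python dependencies (requirements file name is an assumption)
pip install -r requirements.txt

# start a local MongoDB instance on Docker
docker run -d --name dataset-rising-mongo -p 27017:27017 mongo
```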
@@ -137,7 +142,9 @@ python3 pick.py --selector ./examples/select/uncurated.yaml --output /tmp/uncura
### 5. Build a Dataset
After selecting the posts for the dataset, use `build` to download the images and build the dataset.

-By default, the build script prunes all tags that have fewer than 150 samples. To adjust this limit, use `--prune LIMIT`.
+By default, the build script prunes all tags that have fewer than 100 samples. To adjust this limit, use `--min-posts-per-tag LIMIT`.

The build script will also prune all images that have fewer than 10 tags. To adjust this limit, use `--min-tags-per-post LIMIT`.

```bash
cd <dataset-rising>/dataset
@@ -153,6 +160,8 @@ python3 build.py \
```
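
For reference, a fuller `build` invocation using the limits described above might look like the sketch below. The flag names come from `dataset/build.py`, but the input file, agent string, and export path are placeholders, and additional options may be required:

```bash
python3 build.py \
  --source /tmp/selected-posts.jsonl \
  --agent 'mycrawler/1.0 (by myusername)' \
  --export-tags /tmp/tag-counts.json \
  --min-posts-per-tag 100 \
  --min-tags-per-post 10
```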

### 6. Train a Model
Dataset Rising uses [Huggingface Accelerate](https://huggingface.co/docs/accelerate/index) to train Stable Diffusion models.

To train a model, you will need to pick the base model to start from. The `--base-model` can be any
[Diffusers](https://huggingface.co/docs/diffusers/index) compatible model, such as:

@@ -164,6 +173,8 @@ To train a model, you will need to pick the base model to start from. The `--bas
Note that your training results will be improved if you set `--image_width` and `--image_height` to match the
resolution the base model was trained with.
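
For orientation, a minimal single-GPU invocation might look like the sketch below; the flags mirror the multi-GPU example in the Advanced Topics section, and the base model and dataset name are placeholders:

```bash
python3 train.py \
  --pretrained-model-name-or-path 'stabilityai/stable-diffusion-xl-base-1.0' \
  --dataset-name 'username/dataset-name' \
  --resolution 1024 \
  --train-batch-size 32 \
  --learning-rate 4e-6
```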

> This example does not scale to multiple GPUs. See the [Advanced Topics](#advanced-topics) section for multi-GPU training.
```bash
cd <dataset-rising>/train

@@ -233,6 +244,37 @@ python3 import.py ... # main sources and tags
python3 append.py --input /tmp/gelbooru-posts.jsonl --source gelbooru
```

### Multi-GPU Training
Multi-GPU training can be carried out with the [Huggingface Accelerate](https://huggingface.co/docs/accelerate/package_reference/cli) library.

Before training, run `accelerate config` to set up your Multi-GPU environment.

```bash
cd <dataset-rising>/train

# set up environment
accelerate config

# run training
accelerate launch \
--multi_gpu \
train.py \
--pretrained-model-name-or-path 'stabilityai/stable-diffusion-xl-base-1.0' \
--dataset-name 'username/dataset-name' \
--resolution 1024 \
--maintain-aspect-ratio \
--reshuffle-tags \
--tag-separator ' ' \
--random-flip \
--train-batch-size 32 \
--learning-rate 4e-6 \
--use-ema \
--max-grad-norm 1 \
--checkpointing-steps 1000 \
--lr-scheduler constant \
--lr-warmup-steps 0
```

### Architecture

dataset/build.py: 2 changes (1 addition, 1 deletion)
@@ -20,7 +20,7 @@
parser.add_argument('-s', '--source', metavar='FILE', type=str, action='append', help='Post JSONL file(s) to import', required=True)
parser.add_argument('-a', '--agent', metavar='AGENT', type=str, help='Unique user agent string (e.g. "mycrawler/1.0 (by myusername)")', required=True)
parser.add_argument('-e', '--export-tags', metavar='FILE', type=str, help='Export tag counts as a JSON file', required=False, default=None)
-parser.add_argument('--min-posts-per-tag', metavar='COUNT', type=int, help='Minimum number of posts a tag must appear in to be included', required=False, default=150)
+parser.add_argument('--min-posts-per-tag', metavar='COUNT', type=int, help='Minimum number of posts a tag must appear in to be included', required=False, default=100)
parser.add_argument('--min-tags-per-post', metavar='COUNT', type=int, help='Minimum number of tags in a post for the post to be included (counted after min-posts-per-tag limit has been applied)', required=False, default=10)
parser.add_argument('--prefilter', metavar='FILE', type=str, help='Prefilter YAML file', required=False, default='../examples/dataset/prefilter.yaml')
parser.add_argument('--image-width', metavar='PIXELS', type=int, help='Maximum width for stored images', required=False, default=4096)
util/wait.sh: 10 changes (10 additions, 0 deletions)
@@ -0,0 +1,10 @@
#!/bin/bash

if [ -z "${1}" ]
then
echo "Usage: wait.sh PID"
exit 1
fi

# block until the process with the given PID exits (tail returns once the watched PID terminates)
tail --pid=$1 -f /dev/null
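
One way the new helper might be used is to wait for a background job started in the same shell (the long-running command is just a placeholder):

```bash
python3 some_long_job.py &   # placeholder for any long-running task
./util/wait.sh $!            # $! holds the PID of the last background job
```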
