This is the codebase for Conditional Sketch Generation and Completion
This README walks through how to train and sample from the sketch generation and completion model.
We have built a demo web app where you can try out the system and browse some examples: https://matttaylordev.pythonanywhere.com
Python 3.11 or newer is required.
Clone this repository, navigate to it in your terminal, and install the project requirements:

```shell
pip install -r requirements.txt
```

PyTorch may have system-specific requirements. You can find a wheel that works for your system at:
https://pytorch.org/get-started/locally/
The module dataset.py handles downloading, organizing, preprocessing, and loading the datasets. An example of how to use the datasets is in the notebook dataset_visualization.ipynb.
We primarily use a stratified version of the Quick, Draw! Dataset and train with an 85% train, 7.5% validation, 7.5% test split.
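One way to picture an 85/7.5/7.5 split that keeps class proportions is a per-class (stratified) split over sample indices, so every label contributes roughly the same share to each subset. This is a minimal sketch under our own assumptions: `stratified_split` is a hypothetical helper, not the code in dataset.py, and the toy labels stand in for real Quick, Draw! samples.

```python
import random
from collections import defaultdict

def stratified_split(labels, fractions=(0.85, 0.075, 0.075), seed=0):
    """Hypothetical helper: split sample indices into train/val/test per class."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("split fractions must sum to 1")
    # Group sample indices by their class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train, val, test = [], [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = round(len(indices) * fractions[0])
        n_val = round(len(indices) * fractions[1])
        train += indices[:n_train]
        val += indices[n_train:n_train + n_val]
        test += indices[n_train + n_val:]  # whatever remains goes to test
    return train, val, test

labels = ["bird"] * 200 + ["crab"] * 200 + ["guitar"] * 200
train, val, test = stratified_split(labels)
print(len(train), len(val), len(test))  # 510 45 45
```

Splitting each class separately means a rare class cannot end up missing from the validation or test set by chance, which a single global shuffle does not guarantee.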
To train your model, first choose your hyperparameters. They are split into four groups:
- Dataset
- Tokenizer encoding
- Model architecture
- Training configuration
You can set up training either in a Jupyter notebook or with a Python script and a TOML configuration file.
To use a notebook, create a new one in the project directory or start from an existing one in the experiments directory. Our existing training experiment notebooks are named sketch_experiments_*.ipynb.
Below is an example hyperparameter configuration from example.ipynb:
```python
from dataset import QuickDrawDataset
from sketch_tokenizers import DeltaPenPositionTokenizer
from models import SketchTransformerConditional
from runner import SketchTrainer

# Dataset: the Quick, Draw! classes to train on
label_names = ["bird", "crab", "guitar"]
dataset = QuickDrawDataset(label_names=label_names)

# Tokenizer encoding
tokenizer = DeltaPenPositionTokenizer(bins=32)

# Model architecture; vocab_size and num_classes follow from the tokenizer and dataset
model = SketchTransformerConditional(
    vocab_size=len(tokenizer.vocab),
    d_model=512,
    nhead=8,
    num_layers=8,
    max_len=200,
    num_classes=len(label_names),
)

# Training configuration
training_config = {
    "batch_size": 128,
    "num_epochs": 15,
    "learning_rate": 1e-4,
    "log_dir": "logs/sketch_transformer_experiment_2",
    "splits": [0.85, 0.075, 0.075],  # train, validation, test
    # Uncomment to resume from an existing checkpoint:
    # "checkpoint_path": "logs/path/to/existing/checkpoint/model_checkpoint.pt",
}

trainer = SketchTrainer(model, dataset, tokenizer, training_config)
```

Then, in the next cell, run:

```python
trainer.train_mixed(training_config["num_epochs"])
```

If you prefer to train the model from the terminal, follow these steps.
First, create a TOML training configuration file or modify an existing one; we provide some examples in the configs directory. Below is an example training configuration:
```toml
[dataset]
class = "QuickDrawDataset"
label_names = ["bird", "crab", "guitar"]

[tokenizer]
class = "DeltaPenPositionTokenizer"
bins = 32

[model]
class = "SketchTransformerConditional"
d_model = 512
nhead = 8
num_layers = 8
max_len = 200

[training]
batch_size = 128
num_epochs = 15
learning_rate = 1e-4
log_dir = "logs/sketch_transformer_example"
use_padding_mask = false
splits = [0.85, 0.075, 0.075] # 85% train, 7.5% validation, 7.5% test
# Resume from a specific model checkpoint
# checkpoint_path = "logs/path/to/existing/checkpoint/model_checkpoint.pt"
```

Then run:

```shell
python main.py --config path/to/config.toml
```
Training logs and checkpoints are saved to log_dir. To load a saved model and sample from it, follow the notebook experiments/sample_outputs.ipynb. To visualize the metrics collected during training, run the command below (replacing log_dir with your configured log directory) and open http://localhost:6006/ in your browser:

```shell
tensorboard --logdir log_dir
```