Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/data
weights.zip
diffusers-cache
weights
.git
9 changes: 6 additions & 3 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
build:
gpu: true
cuda: "11.6"
python_version: "3.10"
cuda: "11.8"
python_version: "3.11"
python_packages:
- "diffusers==0.11.1"
- "torch==1.13.0"
- "torch==2.0.1"
- "ftfy==6.1.1"
- "scipy==1.9.3"
- "transformers==4.25.1"
- "accelerate==0.15.0"
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.2/pget" && chmod +x /usr/local/bin/pget


predict: "predict.py:Predictor"
44 changes: 36 additions & 8 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os
from typing import List
import subprocess
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from cog import BasePredictor, Input, Path
Expand All @@ -21,11 +23,9 @@
from transformers import CLIPFeatureExtractor


SAFETY_MODEL_CACHE = "diffusers-cache"
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"

if not os.path.exists("weights"):
raise ValueError("dreambooth weights not found")
SAFETY_MODEL_CACHE = "diffusers-cache"
SAFETY_MODEL_URL = 'https://weights.replicate.delivery/default/dreambooth-safety-checker.tar'

DEFAULT_HEIGHT = 512
DEFAULT_WIDTH = 512
Expand All @@ -45,18 +45,38 @@
DEFAULT_PROMPT = "a photo of an astronaut riding a horse on mars"


def download_weights(url, dest):
start = time.time()
print("downloading url: ", url)
print("downloading to: ", dest)
subprocess.check_call(["pget", "-x", str(url), dest], close_fds=False)
print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
def setup(self):

weights = os.environ.get('DREAMBOOTH_WEIGHTS')

if weights is None:
self.tuned = False
return

if not os.path.exists("/src/weights"):
download_weights(weights, "/src/weights")

if not os.path.exists(SAFETY_MODEL_CACHE):
download_weights(SAFETY_MODEL_URL, SAFETY_MODEL_CACHE)

"""Load the model into memory to make running multiple predictions efficient"""
print("Loading Safety pipeline...")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
SAFETY_MODEL_ID,
cache_dir=SAFETY_MODEL_CACHE,
os.path.join(SAFETY_MODEL_CACHE, "stable-diffusion-safety-checker"),
torch_dtype=torch.float16,
local_files_only=True,
).to("cuda")
feature_extractor = CLIPFeatureExtractor.from_pretrained(
"openai/clip-vit-base-patch32", cache_dir=SAFETY_MODEL_CACHE
os.path.join(SAFETY_MODEL_CACHE, "clip-vit-base-patch32")
)

print("Loading SD pipeline...")
Expand All @@ -77,6 +97,8 @@ def setup(self):
feature_extractor=self.txt2img_pipe.feature_extractor,
).to("cuda")

self.tuned = True

@torch.inference_mode()
def predict(
self,
Expand Down Expand Up @@ -138,6 +160,12 @@ def predict(
),
) -> List[Path]:
"""Run a single prediction on the model"""

if not self.tuned:
raise ValueError(
"This is a template model."
)

if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")
Expand Down
10 changes: 7 additions & 3 deletions script/download-weights
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import shutil
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
Expand All @@ -17,9 +18,12 @@ os.makedirs(MODEL_CACHE, exist_ok=True)

saftey_checker = StableDiffusionSafetyChecker.from_pretrained(
SAFETY_MODEL_ID,
cache_dir=MODEL_CACHE,
torch_dtype=torch.float16,
)
saftey_checker.save_pretrained(MODEL_CACHE + "/stable-diffusion-safety-checker")

CLIPFeatureExtractor.from_pretrained(
"openai/clip-vit-base-patch32", cache_dir=MODEL_CACHE
cfe = CLIPFeatureExtractor.from_pretrained(
"openai/clip-vit-base-patch32"
)
cfe.save_pretrained(MODEL_CACHE + "/clip-vit-base-patch32")

80 changes: 0 additions & 80 deletions script/prep.sh

This file was deleted.