Un-hardcode "cuda" as default device name #49

Open. Wants to merge 1 commit into base: main.
8 changes: 4 additions & 4 deletions sgm/inference/api.py
@@ -13,7 +13,7 @@
                                                     EulerEDMSampler,
                                                     HeunEDMSampler,
                                                     LinearMultistepSampler)
-from sgm.util import load_model_from_config
+from sgm.util import load_model_from_config, get_default_device_name
 
 
 class ModelArchitecture(str, Enum):
@@ -158,7 +158,7 @@ def __init__(
         model_id: ModelArchitecture,
         model_path="checkpoints",
         config_path="configs/inference",
-        device="cuda",
+        device: Optional[str] = None,
         use_fp16=True,
     ) -> None:
         if model_id not in model_specs:
@@ -167,10 +167,10 @@
         self.specs = model_specs[self.model_id]
         self.config = str(pathlib.Path(config_path, self.specs.config))
         self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
-        self.device = device
+        self.device = device or get_default_device_name()
         self.model = self._load_model(device=device, use_fp16=use_fp16)
 
-    def _load_model(self, device="cuda", use_fp16=True):
+    def _load_model(self, *, device, use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
         if model is None:
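Note: `get_default_device_name` and `safe_autocast` are defined in sgm/util.py, whose diff is not included in this extract. A minimal sketch of what they plausibly do, for context (the device-preference order and the MPS fallback are assumptions, not taken from the PR):

    import contextlib

    import torch


    def get_default_device_name() -> str:
        # Assumed preference order: CUDA if present, then Apple MPS, else CPU.
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"


    def safe_autocast(device: str, enabled: bool = True):
        # torch.autocast expects a device *type* ("cuda"), not a full device
        # name ("cuda:1"), so strip any index suffix first.
        device_type = device.split(":")[0]
        if device_type in ("cuda", "cpu"):
            return torch.autocast(device_type, enabled=enabled)
        # Device types without stable autocast support (e.g. MPS on older
        # PyTorch versions) simply run without autocast.
        return contextlib.nullcontext()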
31 changes: 22 additions & 9 deletions sgm/inference/helpers.py
@@ -4,13 +4,12 @@
 
 import numpy as np
 import torch
-from PIL import Image
 from einops import rearrange
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
-from torch import autocast
+from PIL import Image
 
-from sgm.util import append_dims
+from sgm.util import append_dims, safe_autocast, get_default_device_name
 
 
 class WatermarkEmbedder:
@@ -111,21 +110,24 @@ def do_sample(
     batch2model_input: Optional[List] = None,
     return_latents=False,
     filter=None,
-    device="cuda",
+    device: Optional[str] = None,
 ):
+    if not device:
+        device = get_default_device_name()
     if force_uc_zero_embeddings is None:
         force_uc_zero_embeddings = []
     if batch2model_input is None:
         batch2model_input = []
 
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with safe_autocast(device):
             with model.ema_scope():
                 num_samples = [num_samples]
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
                     value_dict,
                     num_samples,
+                    device=device,
                 )
                 for key in batch:
                     if isinstance(batch[key], torch.Tensor):
@@ -170,7 +172,13 @@ def denoiser(input, sigma, c):
     return samples
 
 
-def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+def get_batch(
+    keys,
+    value_dict,
+    N: Union[List, ListConfig],
+    *,
+    device: str,
+):
     # Hardcoded demo setups; might undergo some changes in the future
 
     batch = {}
@@ -227,7 +235,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
     return batch, batch_uc
 
 
-def get_input_image_tensor(image: Image.Image, device="cuda"):
+def get_input_image_tensor(image: Image.Image, device: Optional[str] = None):
+    if not device:
+        device = get_default_device_name()
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
     width, height = map(
@@ -252,15 +262,18 @@ def do_img2img(
     return_latents=False,
     skip_encode=False,
     filter=None,
-    device="cuda",
+    device: Optional[str] = None,
 ):
+    if not device:
+        device = get_default_device_name()
     with torch.no_grad():
-        with autocast(device) as precision_scope:
+        with safe_autocast(device):
             with model.ema_scope():
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
                     value_dict,
                     [num_samples],
+                    device=device,
                 )
                 c, uc = model.conditioner.get_unconditional_conditioning(
                     batch,
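With the fallback in place, helper callers can omit `device` entirely and still run on CUDA-less machines. A usage sketch (the file name is illustrative):

    from PIL import Image

    from sgm.inference.helpers import get_input_image_tensor

    # No device argument: falls back to get_default_device_name().
    tensor = get_input_image_tensor(Image.open("input.png"))

    # An explicit device still takes precedence over the fallback.
    tensor_cpu = get_input_image_tensor(Image.open("input.png"), device="cpu")

Note that `get_batch` now takes `device` as a required keyword-only argument rather than defaulting silently, so the internal call sites in `do_sample` and `do_img2img` pass the device they have already resolved.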
23 changes: 18 additions & 5 deletions sgm/models/diffusion.py
@@ -1,6 +1,6 @@
 import math
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union, Optional
 
 import pytorch_lightning as pl
 import torch
@@ -12,8 +12,15 @@
 from ..modules.autoencoding.temporal_ae import VideoDecoder
 from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
 from ..modules.ema import LitEma
-from ..util import (default, disabled_train, get_obj_from_str,
-                    instantiate_from_config, log_txt_as_img)
+from ..util import (
+    default,
+    disabled_train,
+    get_default_device_name,
+    get_obj_from_str,
+    instantiate_from_config,
+    log_txt_as_img,
+    safe_autocast,
+)
 
 
 class DiffusionEngine(pl.LightningModule):
@@ -114,14 +121,20 @@ def get_input(self, batch):
         # image tensors should be scaled to -1 ... 1 and in bchw format
         return batch[self.input_key]
 
+    def _first_stage_autocast_context(self):
+        return safe_autocast(
+            device=get_default_device_name(),
+            enabled=not self.disable_first_stage_autocast,
+        )
+
     @torch.no_grad()
     def decode_first_stage(self, z):
         z = 1.0 / self.scale_factor * z
         n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
 
         n_rounds = math.ceil(z.shape[0] / n_samples)
         all_out = []
-        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+        with self._first_stage_autocast_context():
             for n in range(n_rounds):
                 if isinstance(self.first_stage_model.decoder, VideoDecoder):
                     kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
@@ -139,7 +152,7 @@ def encode_first_stage(self, x):
         n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
         n_rounds = math.ceil(x.shape[0] / n_samples)
         all_out = []
-        with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+        with self._first_stage_autocast_context():
             for n in range(n_rounds):
                 out = self.first_stage_model.encode(
                     x[n * n_samples : (n + 1) * n_samples]
12 changes: 9 additions & 3 deletions sgm/modules/diffusionmodules/openaimodel.py
@@ -10,9 +10,15 @@
 from torch.utils.checkpoint import checkpoint
 
 from ...modules.attention import SpatialTransformer
-from ...modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear,
-                                              normalization,
-                                              timestep_embedding, zero_module)
+from ...modules.diffusionmodules.util import (
+    avg_pool_nd,
+    checkpoint,
+    conv_nd,
+    linear,
+    normalization,
+    timestep_embedding,
+    zero_module,
+)
 from ...modules.video_attention import SpatialVideoTransformer
 from ...util import exists

17 changes: 11 additions & 6 deletions sgm/modules/diffusionmodules/sampling.py
@@ -9,11 +9,14 @@
 from omegaconf import ListConfig, OmegaConf
 from tqdm import tqdm
 
-from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,
-                                                        linear_multistep_coeff,
-                                                        to_d, to_neg_log_sigma,
-                                                        to_sigma)
-from ...util import append_dims, default, instantiate_from_config
+from ...modules.diffusionmodules.sampling_utils import (
+    get_ancestral_step,
+    linear_multistep_coeff,
+    to_d,
+    to_neg_log_sigma,
+    to_sigma,
+)
+from ...util import append_dims, default, instantiate_from_config, get_default_device_name
 
 DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

@@ -25,8 +28,10 @@ def __init__(
         num_steps: Union[int, None] = None,
         guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
         verbose: bool = False,
-        device: str = "cuda",
+        device: Union[str, None] = None,
     ):
+        if device is None:
+            device = get_default_device_name()
         self.num_steps = num_steps
         self.discretization = instantiate_from_config(discretization_config)
         self.guider = instantiate_from_config(
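Samplers follow the same pattern: omitting `device` resolves it at construction time. A sketch (assuming `EulerEDMSampler` forwards these keyword arguments to `BaseDiffusionSampler`, and using the discretization target shipped in this repo's configs):

    from sgm.modules.diffusionmodules.sampling import EulerEDMSampler

    discretization_config = {
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
    }

    # device falls back to get_default_device_name()
    sampler = EulerEDMSampler(discretization_config=discretization_config, num_steps=30)

    # an explicit device still overrides the fallback
    sampler_cpu = EulerEDMSampler(
        discretization_config=discretization_config, num_steps=30, device="cpu"
    )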
44 changes: 31 additions & 13 deletions sgm/modules/encoders/modules.py
@@ -20,8 +20,17 @@
 from ...modules.diffusionmodules.util import (extract_into_tensor,
                                               make_beta_schedule)
 from ...modules.distributions.distributions import DiagonalGaussianDistribution
-from ...util import (append_dims, autocast, count_params, default,
-                     disabled_train, expand_dims_like, instantiate_from_config)
+from ...util import (
+    append_dims,
+    autocast,
+    count_params,
+    default,
+    disabled_train,
+    expand_dims_like,
+    get_default_device_name,
+    instantiate_from_config,
+    safe_autocast,
+)
 
 
 class AbstractEmbModel(nn.Module):
@@ -225,7 +234,9 @@ def forward(self, c):
             c = c[:, None, :]
         return c
 
-    def get_unconditional_conditioning(self, bs, device="cuda"):
+    def get_unconditional_conditioning(self, bs, device=None):
+        if device is None:
+            device = get_default_device_name()
         uc_class = (
             self.n_classes - 1
         )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
@@ -250,9 +261,10 @@ class FrozenT5Embedder(AbstractEmbModel):
     """Uses the T5 transformer encoder for text"""
 
     def __init__(
-        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
+        self, version="google/t5-v1_1-xxl", device=None, max_length=77, freeze=True
     ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
+        device = device or get_default_device_name()
         self.tokenizer = T5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
@@ -277,7 +289,7 @@ def forward(self, text):
             return_tensors="pt",
         )
         tokens = batch_encoding["input_ids"].to(self.device)
-        with torch.autocast("cuda", enabled=False):
+        with safe_autocast(get_default_device_name(), enabled=False):
             outputs = self.transformer(input_ids=tokens)
         z = outputs.last_hidden_state
         return z
@@ -292,9 +304,10 @@ class FrozenByT5Embedder(AbstractEmbModel):
     """
 
     def __init__(
-        self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
+        self, version="google/byt5-base", device=None, max_length=77, freeze=True
     ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
+        device = device or get_default_device_name()
         self.tokenizer = ByT5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
         self.device = device
@@ -319,7 +332,7 @@ def forward(self, text):
             return_tensors="pt",
         )
         tokens = batch_encoding["input_ids"].to(self.device)
-        with torch.autocast("cuda", enabled=False):
+        with safe_autocast(get_default_device_name(), enabled=False):
             outputs = self.transformer(input_ids=tokens)
         z = outputs.last_hidden_state
         return z
@@ -336,14 +349,15 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
     def __init__(
         self,
         version="openai/clip-vit-large-patch14",
-        device="cuda",
+        device=None,
         max_length=77,
         freeze=True,
         layer="last",
         layer_idx=None,
         always_return_pooled=False,
     ):  # clip-vit-base-patch32
         super().__init__()
+        device = device or get_default_device_name()
         assert layer in self.LAYERS
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -404,14 +418,15 @@ def __init__(
         self,
         arch="ViT-H-14",
         version="laion2b_s32b_b79k",
-        device="cuda",
+        device=None,
         max_length=77,
         freeze=True,
         layer="last",
         always_return_pooled=False,
         legacy=True,
     ):
         super().__init__()
+        device = device or get_default_device_name()
         assert layer in self.LAYERS
         model, _, _ = open_clip.create_model_and_transforms(
             arch,
@@ -506,12 +521,13 @@ def __init__(
         self,
         arch="ViT-H-14",
         version="laion2b_s32b_b79k",
-        device="cuda",
+        device=None,
         max_length=77,
         freeze=True,
         layer="last",
     ):
         super().__init__()
+        device = device or get_default_device_name()
         assert layer in self.LAYERS
         model, _, _ = open_clip.create_model_and_transforms(
             arch, device=torch.device("cpu"), pretrained=version
@@ -576,7 +592,7 @@ def __init__(
         self,
         arch="ViT-H-14",
         version="laion2b_s32b_b79k",
-        device="cuda",
+        device=None,
         max_length=77,
         freeze=True,
         antialias=True,
@@ -588,6 +604,7 @@ def __init__(
         init_device=None,
     ):
         super().__init__()
+        device = device or get_default_device_name()
         model, _, _ = open_clip.create_model_and_transforms(
             arch,
             device=torch.device(default(init_device, "cpu")),
@@ -733,11 +750,12 @@ def __init__(
         self,
         clip_version="openai/clip-vit-large-patch14",
         t5_version="google/t5-v1_1-xl",
-        device="cuda",
+        device=None,
         clip_max_length=77,
         t5_max_length=77,
     ):
         super().__init__()
+        device = device or get_default_device_name()
         self.clip_encoder = FrozenCLIPEmbedder(
             clip_version, device, max_length=clip_max_length
         )
@@ -999,7 +1017,7 @@ def forward(
         noise = torch.randn_like(vid)
         vid = vid + noise * append_dims(sigmas, vid.ndim)
 
-        with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
+        with safe_autocast(get_default_device_name(), enabled=not self.disable_encoder_autocast):
             n_samples = (
                 self.en_and_decode_n_samples_a_time
                 if self.en_and_decode_n_samples_a_time is not None
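Note that the embedders use `device = device or get_default_device_name()` rather than an explicit `is None` check, so any falsy value (e.g. an empty string) also triggers the fallback. Either way, constructing an embedder without naming a device now works on CUDA-less machines:

    from sgm.modules.encoders.modules import FrozenCLIPEmbedder

    clip = FrozenCLIPEmbedder()                  # device resolved at construction
    clip_cpu = FrozenCLIPEmbedder(device="cpu")  # explicit device still wins

(The cuda/mps/cpu resolution order assumes the sketch of `get_default_device_name` given earlier.)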