add stable diffusion support on INF2 (#683)
Qing Lan authored May 4, 2023
1 parent a5607d1 commit 7e12511
Showing 7 changed files with 420 additions and 41 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/llm_inf2_integration.yml
@@ -111,6 +111,17 @@ jobs:
          python3 llm/client.py transformers_neuronx opt-1.3b-streaming
          docker rm -f $(docker ps -aq)
          sudo rm -rf models
      - name: Test stable diffusion with handler
        working-directory: tests/integration
        run: |
          rm -rf models
          python3 llm/prepare.py transformers_neuronx stable-diffusion-2.1-base-neuron
          ./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models pytorch-inf2-2 \
            serve
          curl http://127.0.0.1:8080/models
          python3 llm/client.py stable-diffusion stable-diffusion-2.1-base-neuron
          docker rm -f $(docker ps -aq)
          sudo rm -rf models
      - name: On fail step
        if: ${{ failure() }}
        working-directory: tests/integration
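For reference, the stable-diffusion path in the handlers below is switched on purely through model properties. A hypothetical serving.properties for this test (the real one is generated by llm/prepare.py; the property names are taken from the handler code in this commit) might look like:

engine=Python
option.entryPoint=djl_python.transformers-neuronx
# handle() only checks that this key is present, so any value enables the path
option.use_stable_diffusion=true
option.model_id=stabilityai/stable-diffusion-2-1-base
option.tensor_parallel_degree=2
option.dtype=fp32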
262 changes: 262 additions & 0 deletions engines/python/setup/djl_python/stable_diffusion_inf2.py
@@ -0,0 +1,262 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
import os
import torch
import torch.nn as nn
import torch_neuronx
from djl_python.inputs import Input
from djl_python.outputs import Output
from io import BytesIO
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.cross_attention import CrossAttention


class UNetWrap(nn.Module):

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self,
                sample,
                timestep,
                encoder_hidden_states,
                cross_attention_kwargs=None):
        out_tuple = self.unet(sample,
                              timestep,
                              encoder_hidden_states,
                              return_dict=False)
        return out_tuple


class NeuronUNet(nn.Module):

    def __init__(self, unetwrap):
        super().__init__()
        self.unetwrap = unetwrap
        self.config = unetwrap.unet.config
        self.in_channels = unetwrap.unet.in_channels
        self.device = unetwrap.unet.device

    def forward(self,
                sample,
                timestep,
                encoder_hidden_states,
                cross_attention_kwargs=None):
        sample = self.unetwrap(sample,
                               timestep.float().expand((sample.shape[0], )),
                               encoder_hidden_states)[0]
        return UNet2DConditionOutput(sample=sample)
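The two wrappers exist because torch_neuronx.trace wants a module with a flat, tensor-only signature: UNetWrap strips the diffusers keyword and return_dict plumbing down to a plain tuple, while NeuronUNet restores the UNet2DConditionOutput interface (and broadcasts the scalar timestep to the batch) so the pipeline can call the compiled module unchanged. A minimal sketch of how they compose, using the two classes above and an assumed model id (this mirrors what runtime_compile() does further down):

import torch
import torch_neuronx
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base")  # assumed HF model id
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
example = (torch.randn(1, 4, 64, 64),                # latent sample
           torch.tensor(999).float().expand((1, )),  # timestep
           torch.randn(1, 77, 1024))                 # text-encoder states
# Tracing replaces only the inner wrapper; NeuronUNet keeps the
# diffusers-facing interface intact.
pipe.unet.unetwrap = torch_neuronx.trace(pipe.unet.unetwrap, example)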


def get_torch_dtype_from_str(dtype: str):
    if dtype == "fp32":
        return torch.float32
    elif dtype == "fp16":
        return torch.float16
    raise ValueError(
        f"Invalid data type: {dtype}. Only fp32 and fp16 are supported for stable diffusion on INF2"
    )


def get_attention_scores(self, query, key, attn_mask):
    # Drop-in replacement for CrossAttention.get_attention_scores that lowers
    # to plain bmm + scale + softmax (see cust_badbmm below); the service
    # patches it in below for better performance on Neuron.
    dtype = query.dtype

    if self.upcast_attention:
        query = query.float()
        key = key.float()

    if query.size() == key.size():
        # Self-attention: compute the transposed score matrix K.Q^T, softmax
        # over the key axis (dim=1), then permute back; this is equivalent to
        # softmax(Q.K^T, dim=-1).
        attention_scores = cust_badbmm(key, query.transpose(-1, -2))

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = torch.nn.functional.softmax(attention_scores,
                                                      dim=1).permute(0, 2, 1)
        attention_probs = attention_probs.to(dtype)

    else:
        # Cross-attention: standard Q.K^T scores.
        attention_scores = cust_badbmm(query, key.transpose(-1, -2))

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = torch.nn.functional.softmax(attention_scores,
                                                      dim=-1)
        attention_probs = attention_probs.to(dtype)

    return attention_probs


def cust_badbmm(a, b):
    # bmm followed by the attention scale; 0.125 == 1/sqrt(64), the per-head
    # dimension used throughout the SD 2.1 UNet.
    bmm = torch.bmm(a, b)
    scaled = bmm * 0.125
    return scaled
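For intuition, the transposed self-attention branch above is numerically equivalent to the standard softmax(Q.K^T) formulation; a quick standalone check in plain PyTorch (no Neuron required):

import torch

q, k = torch.randn(2, 16, 64), torch.randn(2, 16, 64)
# Transposed variant, as in get_attention_scores.
a = torch.softmax(torch.bmm(k, q.transpose(-1, -2)) * 0.125,
                  dim=1).permute(0, 2, 1)
# Standard variant.
b = torch.softmax(torch.bmm(q, k.transpose(-1, -2)) * 0.125, dim=-1)
assert torch.allclose(a, b, atol=1e-6)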


class StableDiffusionService(object):

    def __init__(self):
        self.pipeline = None
        self.initialized = False
        self.model_id_or_path = None
        self.data_type = None
        self.tensor_parallel_degree = None

    def initialize(self, properties: dict):
        # model_id can point to a huggingface model_id or a local directory.
        # If option.model_id points to an s3 bucket, we download it and set
        # model_id to the download directory. Otherwise, we assume model
        # artifacts are in the model_dir.
        self.model_id_or_path = properties.get("model_id") or properties.get(
            "model_dir")
        self.tensor_parallel_degree = int(
            properties.get("tensor_parallel_degree", 2))
        self.data_type = get_torch_dtype_from_str(
            properties.get("dtype", "fp32"))
        kwargs = {"torch_dtype": self.data_type}
        if "use_auth_token" in properties:
            kwargs["use_auth_token"] = properties["use_auth_token"]

        self.pipeline = StableDiffusionPipeline.from_pretrained(
            self.model_id_or_path, **kwargs)
        self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipeline.scheduler.config)

        # Replace the original cross-attention module with the custom
        # cross-attention implementation above for better performance.
        CrossAttention.get_attention_scores = get_attention_scores

        if os.path.exists(os.path.join(self.model_id_or_path,
                                       "compiled_model")):
            logging.info("Loading pre-compiled model")
            self.load_compiled(
                os.path.join(self.model_id_or_path, "compiled_model"))
        else:
            self.runtime_compile()

        if "save_compiled_model" in properties:
            self.save_compiled(
                os.path.join(properties.get("save_compiled_model"),
                             "compiled_model"))

        # Shard the compiled UNet across the NeuronCores so each core serves
        # a slice of the batch (e.g. the two classifier-free-guidance
        # latents); dynamic batching stays off because the graph was traced
        # with a fixed batch size.
        device_ids = list(range(self.tensor_parallel_degree))
        self.pipeline.unet.unetwrap = torch_neuronx.DataParallel(
            self.pipeline.unet.unetwrap,
            device_ids,
            set_dynamic_batching=False)

        self.initialized = True

    def runtime_compile(self):
        logging.warning(
            "Runtime compilation is not recommended, please precompile the model"
        )
        logging.info("Model compilation started...")
        COMPILER_WORKDIR_ROOT = "/tmp/neuron_compiler"
        self.pipeline.unet = NeuronUNet(UNetWrap(self.pipeline.unet))

        # The example inputs pin the traced shapes: 1x4x64x64 latents
        # correspond to 512x512 images (the VAE upscales 8x), and 1x77x1024
        # matches the SD 2.x text-encoder output (77 tokens, hidden size 1024).
        sample_1b = torch.randn([1, 4, 64, 64])
        timestep_1b = torch.tensor(999).float().expand((1, ))
        encoder_hidden_states_1b = torch.randn([1, 77, 1024])
        example_inputs = sample_1b, timestep_1b, encoder_hidden_states_1b

        logging.info("Compiling UNET...")
        self.pipeline.unet.unetwrap = torch_neuronx.trace(
            self.pipeline.unet.unetwrap,
            example_inputs,
            compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),
            compiler_args=[
                "--internal-hlo2penguin-options=--expand-batch-norm-training",
                "--policy=3"
            ])

        logging.info("Compiling post_quant_conv_in...")
        # Compile the vae post_quant_conv
        post_quant_conv_in = torch.randn([1, 4, 64, 64])
        self.pipeline.vae.post_quant_conv = torch_neuronx.trace(
            self.pipeline.vae.post_quant_conv,
            post_quant_conv_in,
            compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT,
                                          'vae_post_quant_conv'))

        logging.info("Compiling VAE Decoder...")
        # Compile the vae decoder
        decoder_in = torch.randn([1, 4, 64, 64])
        self.pipeline.vae.decoder = torch_neuronx.trace(
            self.pipeline.vae.decoder,
            decoder_in,
            compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT,
                                          'vae_decoder'),
            compiler_args=[
                "--tensorizer-options=--max-dma-access-free-depth=3",
                "--policy=3"
            ])

    def save_compiled(self, saved_dir):
        if not os.path.exists(saved_dir):
            os.makedirs(saved_dir)
        # Save the compiled unet
        unet_filename = os.path.join(saved_dir, 'unet.pt')
        torch.jit.save(self.pipeline.unet.unetwrap, unet_filename)
        # Save the compiled vae post_quant_conv
        post_quant_conv_filename = os.path.join(saved_dir,
                                                'vae_post_quant_conv.pt')
        torch.jit.save(self.pipeline.vae.post_quant_conv,
                       post_quant_conv_filename)
        # Save the compiled vae decoder
        decoder_filename = os.path.join(saved_dir, 'vae_decoder.pt')
        torch.jit.save(self.pipeline.vae.decoder, decoder_filename)

    def load_compiled(self, saved_dir):
        post_quant_conv_filename = os.path.join(saved_dir,
                                                'vae_post_quant_conv.pt')
        self.pipeline.vae.post_quant_conv = torch.jit.load(
            post_quant_conv_filename)
        decoder_filename = os.path.join(saved_dir, 'vae_decoder.pt')
        self.pipeline.vae.decoder = torch.jit.load(decoder_filename)
        self.pipeline.unet = NeuronUNet(UNetWrap(self.pipeline.unet))
        unet_filename = os.path.join(saved_dir, 'unet.pt')
        self.pipeline.unet.unetwrap = torch.jit.load(unet_filename)
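Since runtime_compile() warns against on-the-fly compilation, save_compiled and load_compiled enable an ahead-of-time flow: compile once on an inf2 host with save_compiled_model set, then ship the resulting compiled_model directory next to the model so initialize() takes the load_compiled() path. A sketch of such an offline driver (the model id and output path are assumptions):

from djl_python.stable_diffusion_inf2 import StableDiffusionService

svc = StableDiffusionService()
svc.initialize({
    "model_id": "stabilityai/stable-diffusion-2-1-base",  # assumed HF id
    "tensor_parallel_degree": 2,
    "dtype": "fp32",
    # Artifacts land in /tmp/sd21-neuron/compiled_model.
    "save_compiled_model": "/tmp/sd21-neuron",
})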

    def infer(self, inputs: Input):
        try:
            content_type = inputs.get_property("Content-Type")
            if content_type == "application/json":
                request = inputs.get_as_json()
                prompt = request.pop("prompt")
                params = request.pop("parameters", {})
                result = self.pipeline(prompt, **params)
            elif content_type and content_type.startswith("text/"):
                prompt = inputs.get_as_string()
                result = self.pipeline(prompt)
            else:
                init_image = Image.open(BytesIO(
                    inputs.get_as_bytes())).convert("RGB")
                request = inputs.get_as_json("json")
                prompt = request.pop("prompt")
                params = request.pop("parameters", {})
                result = self.pipeline(prompt, image=init_image, **params)

            img = result.images[0]
            buf = BytesIO()
            img.save(buf, format="PNG")
            byte_img = buf.getvalue()
            outputs = Output().add(byte_img).add_property(
                "content-type", "image/png")

        except Exception as e:
            logging.exception("Neuron inference failed")
            outputs = Output().error(str(e))
        return outputs
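Given the content types infer() accepts, a minimal client sketch in Python (assuming the model is served locally and reachable via DJL Serving's /invocations endpoint, which is presumably what the integration test's client exercises):

import requests

resp = requests.post(
    "http://127.0.0.1:8080/invocations",
    json={
        "prompt": "a photo of an astronaut riding a horse on mars",
        "parameters": {"num_inference_steps": 25},
    },
)
resp.raise_for_status()
# infer() responds with image/png bytes.
with open("output.png", "wb") as f:
    f.write(resp.content)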
5 changes: 5 additions & 0 deletions engines/python/setup/djl_python/transformers-neuronx.py
@@ -22,6 +22,7 @@
from transformers_neuronx.module import save_pretrained_split
from transformers_neuronx.opt.model import OPTForSampling
from djl_python import Input, Output
from djl_python.stable_diffusion_inf2 import StableDiffusionService
from djl_python.streaming_utils import StreamingUtils

model = None
@@ -62,6 +63,7 @@ def convert_opt(self, amp):
            block.fc2.to(dtype)
        self.model.lm_head.to(dtype)
        logging.info(f"Saving INF2 model to {load_path} ...")
        save_pretrained_split(self.model, load_path)
        with open(os.path.join(load_path, "verify"), "w") as f:
            f.writelines("opt-converted")
@@ -207,7 +209,10 @@ def infer(self, inputs):


def handle(inputs: Input):
    global _service
    if not _service.initialized:
        if "use_stable_diffusion" in inputs.get_properties():
            _service = StableDiffusionService()
        _service.initialize(inputs.get_properties())

    if inputs.is_empty():
3 changes: 2 additions & 1 deletion serving/docker/pytorch-inf2.Dockerfile
@@ -17,6 +17,7 @@ ARG torch_neuronx_version=1.13.1.1.7.0
ARG transformers_neuronx_version=0.3.32
ARG transformers_version=4.28.1
ARG accelerate_version=0.18.0
ARG diffusers_version=0.14.0
EXPOSE 8080

# Sets up Path for Neuron tools
@@ -56,7 +57,7 @@ RUN mkdir -p /opt/djl/bin && cp scripts/telemetry.sh /opt/djl/bin && \
    scripts/install_inferentia2.sh && \
    pip install transformers==${transformers_version} accelerate==${accelerate_version} \
    neuronx-cc==2.6.* torch_neuronx==${torch_neuronx_version} transformers-neuronx==${transformers_neuronx_version} \
-   --extra-index-url=https://pip.repos.neuron.amazonaws.com && \
+   diffusers==${diffusers_version} Pillow --extra-index-url=https://pip.repos.neuron.amazonaws.com && \
    scripts/install_s5cmd.sh x64 && \
    scripts/patch_oss_dlc.sh python && \
    useradd -m -d /home/djl djl && \
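Worth noting: the new diffusers pin is load-bearing. The handler's import of CrossAttention from diffusers.models.cross_attention relies on the pre-0.15 diffusers module layout (later releases moved and renamed the class to Attention under diffusers.models.attention_processor), so bumping diffusers_version here would break the attention monkey-patch in stable_diffusion_inf2.py.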