diff --git a/engines/python/setup/djl_python/stable_diffusion_inf2.py b/engines/python/setup/djl_python/stable_diffusion_inf2.py
new file mode 100644
index 0000000000..c7171961c2
--- /dev/null
+++ b/engines/python/setup/djl_python/stable_diffusion_inf2.py
@@ -0,0 +1,261 @@
+#!/usr/bin/env python
+#
+# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+import logging
+import os
+import torch
+import torch.nn as nn
+import torch_neuronx
+from djl_python.inputs import Input
+from djl_python.outputs import Output
+from io import BytesIO
+from PIL import Image
+from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+from diffusers.models.cross_attention import CrossAttention
+
+
+class UNetWrap(nn.Module):
+
+    def __init__(self, unet):
+        super().__init__()
+        self.unet = unet
+
+    def forward(self,
+                sample,
+                timestep,
+                encoder_hidden_states,
+                cross_attention_kwargs=None):
+        out_tuple = self.unet(sample,
+                              timestep,
+                              encoder_hidden_states,
+                              return_dict=False)
+        return out_tuple
+
+
+class NeuronUNet(nn.Module):
+
+    def __init__(self, unetwrap):
+        super().__init__()
+        self.unetwrap = unetwrap
+        self.config = unetwrap.unet.config
+        self.in_channels = unetwrap.unet.in_channels
+        self.device = unetwrap.unet.device
+
+    def forward(self,
+                sample,
+                timestep,
+                encoder_hidden_states,
+                cross_attention_kwargs=None):
+        sample = self.unetwrap(sample,
+                               timestep.float().expand((sample.shape[0], )),
+                               encoder_hidden_states)[0]
+        return UNet2DConditionOutput(sample=sample)
+
+
+def get_torch_dtype_from_str(dtype: str):
+    if dtype == "fp32":
+        return torch.float32
+    elif dtype == "fp16":
+        return torch.float16
+    raise ValueError(
+        f"Invalid data type: {dtype}. stable-diffusion-inf2 only supports fp32 and fp16"
+    )
+
+
+def get_attention_scores(self, query, key, attn_mask):
+    dtype = query.dtype
+
+    if self.upcast_attention:
+        query = query.float()
+        key = key.float()
+
+    if query.size() == key.size():
+        attention_scores = cust_badbmm(key, query.transpose(-1, -2))
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = torch.nn.functional.softmax(attention_scores,
+                                                      dim=1).permute(0, 2, 1)
+        attention_probs = attention_probs.to(dtype)
+
+    else:
+        attention_scores = cust_badbmm(query, key.transpose(-1, -2))
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
+        attention_probs = attention_probs.to(dtype)
+
+    return attention_probs
+
+
+def cust_badbmm(a, b):
+    bmm = torch.bmm(a, b)
+    scaled = bmm * 0.125  # fixed 1/sqrt(d_head) scale, d_head=64 in the SD2 UNet
+    return scaled
+
+
+class StableDiffusionService(object):
+
+    def __init__(self):
+        self.pipeline = None
+        self.initialized = False
+        self.ds_config = None
+        self.logger = logging.getLogger()
+        self.model_id_or_path = None
+        self.data_type = None
+        self.device = None
+        self.tensor_parallel_degree = None
+        self.save_image_dir = None
+
+    def initialize(self, properties: dict):
+        # model_id can point to a huggingface model_id or a local directory.
+        # If option.model_id points to an s3 bucket, we download it and set model_id to the download directory.
+        # Otherwise we assume model artifacts are in the model_dir.
+        self.model_id_or_path = properties.get("model_id") or properties.get(
+            "model_dir")
+        self.tensor_parallel_degree = int(
+            properties.get("tensor_parallel_degree", 1))
+        self.data_type = get_torch_dtype_from_str(
+            properties.get("dtype", "fp32"))
+        kwargs = {"torch_dtype": self.data_type}
+        if "use_auth_token" in properties:
+            kwargs["use_auth_token"] = properties["use_auth_token"]
+
+        pipe = StableDiffusionPipeline.from_pretrained(self.model_id_or_path,
+                                                       **kwargs)
+        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+            pipe.scheduler.config)
+
+        if os.path.exists(os.path.join(self.model_id_or_path,
+                                       "compiled_model")):
+            logging.info("Loading pre-compiled model")
+            self.load_compiled(
+                pipe, os.path.join(self.model_id_or_path, "compiled_model"))
+        else:
+            self.runtime_compile(pipe)
+
+        # Replace the original attention-score method with the Neuron-friendly version above for better performance
+        CrossAttention.get_attention_scores = get_attention_scores
+
+        self.pipeline = pipe
+        self.initialized = True
+
+    def runtime_compile(self, pipe):
+        logging.warning(
+            "Runtime compilation is not recommended, please precompile the model"
+        )
+        logging.info("Model compilation started...")
+        COMPILER_WORKDIR_ROOT = "/tmp/neuron_compiler"
+        pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
+
+        sample_1b = torch.randn([1, 4, 64, 64])
+        timestep_1b = torch.tensor(999).float().expand((1, ))
+        encoder_hidden_states_1b = torch.randn([1, 77, 1024])
+        example_inputs = sample_1b, timestep_1b, encoder_hidden_states_1b
+
+        logging.info("Compiling UNET...")
+        pipe.unet.unetwrap = torch_neuronx.trace(
+            pipe.unet.unetwrap,
+            example_inputs,
+            compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),
+            compiler_args=[
+                "--internal-hlo2penguin-options=--expand-batch-norm-training",
+                "--policy=3"
+            ])
+
+        device_ids = [idx for idx in range(self.tensor_parallel_degree)]
+        pipe.unet.unetwrap = torch_neuronx.DataParallel(
+            pipe.unet.unetwrap, device_ids, set_dynamic_batching=False)
+
+        logging.info("Compiling VAE post_quant_conv...")
+        # Compile vae post_quant_conv
+        post_quant_conv_in = torch.randn([1, 4, 64, 64])
+        pipe.vae.post_quant_conv = torch_neuronx.trace(
+            pipe.vae.post_quant_conv,
+            post_quant_conv_in,
+            compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT,
+                                          'vae_post_quant_conv'))
+
+        logging.info("Compiling VAE Decoder...")
+        # Compile vae decoder
+        decoder_in = torch.randn([1, 4, 64, 64])
+        pipe.vae.decoder = torch_neuronx.trace(
+            pipe.vae.decoder,
+            decoder_in,
+            compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT,
+                                          'vae_decoder'),
+            compiler_args=[
+                "--tensorizer-options=--max-dma-access-free-depth=3",
+                "--policy=3"
+            ])
+
+    def save_compiled(self, saved_dir):
+        # Save the compiled unet
+        unet_filename = os.path.join(saved_dir, 'unet/model.pt')
+        torch.jit.save(self.pipeline.unet.unetwrap, unet_filename)
+        # Save the compiled vae post_quant_conv
+        post_quant_conv_filename = os.path.join(
+            saved_dir, 'vae_post_quant_conv/model.pt')
+        torch.jit.save(self.pipeline.vae.post_quant_conv,
+                       post_quant_conv_filename)
+        # Save the compiled vae decoder
+        decoder_filename = os.path.join(saved_dir, 'vae_decoder/model.pt')
+        torch.jit.save(self.pipeline.vae.decoder, decoder_filename)
+
+    def load_compiled(self, pipe, saved_dir):
+        pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
+        unet_filename = os.path.join(saved_dir, 'unet/model.pt')
+        pipe.unet.unetwrap = torch.jit.load(unet_filename)
+        post_quant_conv_filename = os.path.join(
+            saved_dir, 'vae_post_quant_conv/model.pt')
+        pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)
+        decoder_filename = os.path.join(saved_dir, 'vae_decoder/model.pt')
+        pipe.vae.decoder = torch.jit.load(decoder_filename)
+
+        device_ids = [idx for idx in range(self.tensor_parallel_degree)]
+        pipe.unet.unetwrap = torch_neuronx.DataParallel(
+            pipe.unet.unetwrap, device_ids, set_dynamic_batching=False)
+
+    def infer(self, inputs: Input):
+        try:
+            content_type = inputs.get_property("Content-Type")
+            if content_type == "application/json":
+                request = inputs.get_as_json()
+                prompt = request.pop("prompt")
+                params = request.pop("parameters", {})
+                result = self.pipeline(prompt, **params)
+            elif content_type and content_type.startswith("text/"):
+                prompt = inputs.get_as_string()
+                result = self.pipeline(prompt)
+            else:
+                init_image = Image.open(BytesIO(
+                    inputs.get_as_bytes())).convert("RGB")
+                request = inputs.get_as_json("json")
+                prompt = request.pop("prompt")
+                params = request.pop("parameters", {})
+                result = self.pipeline(prompt, image=init_image, **params)
+
+            img = result.images[0]
+            buf = BytesIO()
+            img.save(buf, format="PNG")
+            byte_img = buf.getvalue()
+            outputs = Output().add(byte_img).add_property(
+                "content-type", "image/png")
+
+        except Exception as e:
+            logging.exception("Neuron inference failed")
+            outputs = Output().error(str(e))
+        return outputs
diff --git a/engines/python/setup/djl_python/transformers-neuronx.py b/engines/python/setup/djl_python/transformers-neuronx.py
index a53abeda92..b1295083cb 100644
--- a/engines/python/setup/djl_python/transformers-neuronx.py
+++ b/engines/python/setup/djl_python/transformers-neuronx.py
@@ -22,6 +22,7 @@
 from transformers_neuronx.module import save_pretrained_split
 from transformers_neuronx.opt.model import OPTForSampling
 from djl_python import Input, Output
+from djl_python.stable_diffusion_inf2 import StableDiffusionService
 from djl_python.streaming_utils import StreamingUtils
 
 model = None
@@ -62,6 +63,6 @@ def convert_opt(self, amp):
             block.fc2.to(dtype)
         self.model.lm_head.to(dtype)
-        logging.info(f"Saving to INF2 model to {load_path} ...")
+        logging.info(f"Saving INF2 model to {load_path} ...")
         save_pretrained_split(self.model, load_path)
         with open(os.path.join(load_path, "verify"), "w") as f:
             f.writelines("opt-converted")
@@ -207,7 +208,10 @@ def infer(self, inputs):
 
 
 def handle(inputs: Input):
+    global _service
     if not _service.initialized:
+        if "use_stable_diffusion" in inputs.get_properties():
+            _service = StableDiffusionService()
         _service.initialize(inputs.get_properties())
 
     if inputs.is_empty():
diff --git a/serving/docker/pytorch-inf2.Dockerfile b/serving/docker/pytorch-inf2.Dockerfile
index 874e82d8fc..0d7a6599f9 100644
--- a/serving/docker/pytorch-inf2.Dockerfile
+++ b/serving/docker/pytorch-inf2.Dockerfile
@@ -17,6 +17,7 @@ ARG torch_neuronx_version=1.13.1.1.7.0
 ARG transformers_neuronx_version=0.3.32
 ARG transformers_version=4.28.1
 ARG accelerate_version=0.18.0
+ARG diffusers_version=0.14.0
 EXPOSE 8080
 
 # Sets up Path for Neuron tools
@@ -56,7 +57,7 @@ RUN mkdir -p /opt/djl/bin && cp scripts/telemetry.sh /opt/djl/bin && \
     scripts/install_inferentia2.sh && \
     pip install transformers==${transformers_version} accelerate==${accelerate_version} \
     neuronx-cc==2.6.* torch_neuronx==${torch_neuronx_version} transformers-neuronx==${transformers_neuronx_version} \
-    --extra-index-url=https://pip.repos.neuron.amazonaws.com && \
+    diffusers==${diffusers_version} Pillow --extra-index-url=https://pip.repos.neuron.amazonaws.com && \
     scripts/install_s5cmd.sh x64 && \
     scripts/patch_oss_dlc.sh python && \
     useradd -m -d /home/djl djl && \
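
Usage sketch (reviewer aid, not part of the patch): one plausible way to exercise
the new stable-diffusion handler end to end. The property keys below mirror the
ones read in StableDiffusionService.initialize() and handle() above; the model
name, endpoint URL, TP degree, and prompt are illustrative placeholders, and the
serving.properties/option.* convention is assumed to behave as in the other DJL
Serving Python engines.

serving.properties (assumed layout):

    engine=Python
    option.entryPoint=djl_python.transformers-neuronx
    option.model_id=stabilityai/stable-diffusion-2-1-base
    option.use_stable_diffusion=true
    option.tensor_parallel_degree=2
    option.dtype=fp16

client sketch (Python):

    import requests

    # Hits the JSON branch of StableDiffusionService.infer(): "prompt" is
    # required, "parameters" is forwarded to the diffusers pipeline call.
    resp = requests.post(
        "http://localhost:8080/predictions/sd21",  # placeholder endpoint
        json={
            "prompt": "a photo of an astronaut riding a horse on mars",
            "parameters": {"num_inference_steps": 25}
        },
        headers={"Content-Type": "application/json"})
    resp.raise_for_status()

    # The handler responds with raw PNG bytes (content-type image/png).
    with open("output.png", "wb") as f:
        f.write(resp.content)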