Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[INF2] add bf16 support to SD #700

Merged
merged 1 commit into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
add bf16 support
  • Loading branch information
Qing Lan committed May 10, 2023
commit 92fa6e84b2766ab3f8e8621fb1cdb88d9c77707d
11 changes: 11 additions & 0 deletions .github/workflows/llm_inf2_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,17 @@ jobs:
python3 llm/client.py stable-diffusion stable-diffusion-2.1-base-neuron
docker rm -f $(docker ps -aq)
sudo rm -rf models
- name: Test stable diffusion bf16 with handler
working-directory: tests/integration
run: |
rm -rf models
python3 llm/prepare.py transformers_neuronx stable-diffusion-2.1-base-neuron-bf16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need both bf16 and fp32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to test on both

./launch_container.sh deepjavalibrary/djl-serving:$DJLSERVING_DOCKER_TAG $PWD/models pytorch-inf2-2 \
serve
curl http://127.0.0.1:8080/models
python3 llm/client.py stable-diffusion stable-diffusion-2.1-base-neuron
docker rm -f $(docker ps -aq)
sudo rm -rf models
- name: On fail step
if: ${{ failure() }}
working-directory: tests/integration
Expand Down
25 changes: 14 additions & 11 deletions engines/python/setup/djl_python/stable_diffusion_inf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def forward(self,
timestep,
encoder_hidden_states,
cross_attention_kwargs=None):
sample = self.unetwrap(sample,
timestep.float().expand((sample.shape[0], )),
encoder_hidden_states)[0]
sample = self.unetwrap(
sample,
timestep.to(sample.dtype).expand((sample.shape[0], )),
encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)


Expand All @@ -78,10 +79,10 @@ def forward(self, emb, attention_mask=None):
def get_torch_dtype_from_str(dtype: str):
if dtype == "fp32":
return torch.float32
elif dtype == "fp16":
return torch.float16
elif dtype == "bf16":
return torch.bfloat16
raise ValueError(
f"Invalid data type: {dtype}. DeepSpeed currently only supports fp16 for stable diffusion"
f"Invalid data type: {dtype}. NeuronX currently only supports fp32 and bf16 for stable diffusion"
)


Expand Down Expand Up @@ -197,9 +198,11 @@ def runtime_compile(self):

self.pipeline.unet = NeuronUNet(UNetWrap(self.pipeline.unet))

sample_1b = torch.randn([1, 4, 64, 64])
timestep_1b = torch.tensor(999).float().expand((1, ))
encoder_hidden_states_1b = torch.randn([1, 77, 1024])
sample_1b = torch.randn([1, 4, 64, 64]).to(self.data_type)
timestep_1b = torch.tensor(999).float().expand(
(1, )).to(self.data_type)
encoder_hidden_states_1b = torch.randn([1, 77,
1024]).to(self.data_type)
example_inputs = sample_1b, timestep_1b, encoder_hidden_states_1b

logging.info("Compiling UNET...")
Expand All @@ -214,7 +217,7 @@ def runtime_compile(self):

logging.info("Compiling post_quant_conv_in...")
# Compile vae post_quant_conv
post_quant_conv_in = torch.randn([1, 4, 64, 64])
post_quant_conv_in = torch.randn([1, 4, 64, 64]).to(self.data_type)
self.pipeline.vae.post_quant_conv = torch_neuronx.trace(
self.pipeline.vae.post_quant_conv,
post_quant_conv_in,
Expand All @@ -223,7 +226,7 @@ def runtime_compile(self):

logging.info("Compiling VAE Decoder...")
# Compile vae decoder
decoder_in = torch.randn([1, 4, 64, 64])
decoder_in = torch.randn([1, 4, 64, 64]).to(self.data_type)
self.pipeline.vae.decoder = torch_neuronx.trace(
self.pipeline.vae.decoder,
decoder_in,
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/llm/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@
"option.model_id": "s3://djl-llm/stable-diffusion-2-1-base-compiled/",
"option.tensor_parallel_degree": 2,
"option.use_stable_diffusion": True
},
"stable-diffusion-2.1-base-neuron-bf16": {
"option.model_id": "s3://djl-llm/stable-diffusion-2-1-base-compiled-bf16/",
"option.tensor_parallel_degree": 2,
"option.dtype": "bf16",
"option.use_stable_diffusion": True
}
}

Expand Down