# Update Remote VAE blog #2714

**Merged**
@@ -16,6 +16,8 @@ Therefore, we want to pilot an idea with the community — delegating the decodi
No data is stored or tracked, and code is open source. We made some changes to [huggingface-inference-toolkit](https://github.com/hlky/huggingface-inference-toolkit/tree/fix-text-support-binary) and use [custom handlers](https://huggingface.co/hlky/sd-vae-ft-mse/blob/main/handler.py).

This experimental feature is developed by [Diffusers 🧨](https://huggingface.co/docs/diffusers/hybrid_inference/overview).

**Table of contents**:

- [Getting started](#getting-started)
@@ -37,141 +39,14 @@ Below, we cover three use cases where we think this remote VAE inference would b
First, we have created a helper method for interacting with Remote VAEs.

> [!NOTE]
> Install `diffusers` from `main` to run the code.
> `pip install git+https://github.com/huggingface/diffusers@main`
<details><summary>Code</summary>
<p>

```python
from typing import cast, List, Literal, Optional, Union

import base64
import io
import json
import requests
import torch
from PIL import Image

from diffusers.image_processor import VaeImageProcessor
from diffusers.video_processor import VideoProcessor
from safetensors.torch import _tobytes

DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "uint8": torch.uint8,
}


def remote_decode(
    endpoint: str,
    tensor: torch.Tensor,
    processor: Optional[Union[VaeImageProcessor, VideoProcessor]] = None,
    do_scaling: bool = True,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["base64", "binary"] = "base64",
    output_tensor_type: Literal["base64", "binary"] = "base64",
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, torch.Tensor]:
    if tensor.ndim == 3 and height is None and width is None:
        raise ValueError("`height` and `width` required for packed latents.")
    if output_type == "pt" and partial_postprocess is False and processor is None:
        raise ValueError(
            "`processor` is required with `output_type='pt'` and `partial_postprocess=False`."
        )
    # Request parameters tell the endpoint how to decode and what to return.
    headers = {}
    parameters = {
        "do_scaling": do_scaling,
        "output_type": output_type,
        "partial_postprocess": partial_postprocess,
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype).split(".")[-1],
    }
    if height is not None and width is not None:
        parameters["height"] = height
        parameters["width"] = width
    # Serialize the latent tensor and set Content-Type/Accept headers accordingly.
    tensor_data = _tobytes(tensor, "tensor")
    if input_tensor_type == "base64":
        headers["Content-Type"] = "tensor/base64"
    elif input_tensor_type == "binary":
        headers["Content-Type"] = "tensor/binary"
    if output_type == "pil" and image_format == "jpg" and processor is None:
        headers["Accept"] = "image/jpeg"
    elif output_type == "pil" and image_format == "png" and processor is None:
        headers["Accept"] = "image/png"
    elif (output_tensor_type == "base64" and output_type == "pt") or (
        output_tensor_type == "base64"
        and output_type == "pil"
        and processor is not None
    ):
        headers["Accept"] = "tensor/base64"
    elif (output_tensor_type == "binary" and output_type == "pt") or (
        output_tensor_type == "binary"
        and output_type == "pil"
        and processor is not None
    ):
        headers["Accept"] = "tensor/binary"
    elif output_type == "mp4":
        headers["Accept"] = "text/plain"
    if input_tensor_type == "base64":
        kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
    elif input_tensor_type == "binary":
        kwargs = {"data": tensor_data}
    response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
    if not response.ok:
        raise RuntimeError(response.json())
    # For tensor responses, rebuild the torch.Tensor from the returned bytes.
    if output_type == "pt" or (output_type == "pil" and processor is not None):
        if output_tensor_type == "base64":
            content = response.json()
            output_tensor = base64.b64decode(content["inputs"])
            parameters = content["parameters"]
            shape = parameters["shape"]
            dtype = parameters["dtype"]
        elif output_tensor_type == "binary":
            output_tensor = response.content
            parameters = response.headers
            shape = json.loads(parameters["shape"])
            dtype = parameters["dtype"]
        torch_dtype = DTYPE_MAP[dtype]
        output_tensor = torch.frombuffer(
            bytearray(output_tensor), dtype=torch_dtype
        ).reshape(shape)
    if output_type == "pt":
        if partial_postprocess:
            output = [Image.fromarray(image.numpy()) for image in output_tensor]
            if len(output) == 1:
                output = output[0]
        else:
            if processor is None:
                output = output_tensor
            else:
                if isinstance(processor, VideoProcessor):
                    output = cast(
                        List[Image.Image],
                        processor.postprocess_video(output_tensor, output_type="pil")[0],
                    )
                else:
                    output = cast(
                        Image.Image,
                        processor.postprocess(output_tensor, output_type="pil")[0],
                    )
    elif output_type == "pil" and processor is None:
        output = Image.open(io.BytesIO(response.content)).convert("RGB")
    elif output_type == "pil" and processor is not None:
        output = [
            Image.fromarray(image)
            for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255)
            .round()
            .astype("uint8")
        ]
    elif output_type == "mp4":
        output = response.content
    return output
```

The update replaces this inline helper with the version that now ships in `diffusers`:

```python
from diffusers.utils.remote_utils import remote_decode
```
</p>
</details>
@@ -188,6 +63,7 @@ Here, we show how to use the remote VAE on random tensors.
```python
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    scaling_factor=0.18215,
)
```
@@ -209,6 +85,8 @@ image = remote_decode(

```python
    tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
    height=1024,
    width=1024,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
```
@@ -246,70 +124,6 @@ with open("video.mp4", "wb") as f:
</video>
</figure>
### Options

Let's review the available options.
```python
def remote_decode(
    endpoint: str,
    tensor: torch.Tensor,
    processor: Optional[Union[VaeImageProcessor, VideoProcessor]] = None,
    do_scaling: bool = True,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["base64", "binary"] = "base64",
    output_tensor_type: Literal["base64", "binary"] = "base64",
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, torch.Tensor]:
```
#### Overview of decoding

There are three stages to decoding in a pipeline: `scaling` -> `decode` -> `postprocess`.

The options below let the Remote VAE slot into the pipeline at any of these stages. For orientation, the sketch below shows roughly what the three stages look like when a pipeline decodes locally.
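A sketch following diffusers conventions, assuming a standard `pipe` with `vae` and `image_processor` attributes:

```python
# scaling: undo the latent scaling the VAE was trained with
latents = latents / pipe.vae.config.scaling_factor
# decode: latents -> image tensor in [-1, 1]
image = pipe.vae.decode(latents, return_dict=False)[0]
# postprocess: denormalize and convert to PIL
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
```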
#### `processor`

With `output_type="pt"`, the endpoint returns a `torch.Tensor` from before `postprocess`; the final postprocessing and image creation are done locally.

With `output_type="pil"`, video models require `processor=VideoProcessor()` for some local postprocessing.
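A minimal sketch, following the inline helper above and reusing the SD VAE endpoint from the earlier examples:

```python
from diffusers.image_processor import VaeImageProcessor

# The endpoint returns a pre-`postprocess` tensor; `VaeImageProcessor`
# turns it into a PIL image locally.
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    processor=VaeImageProcessor(),
    output_type="pt",
)
```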
#### `do_scaling`

- `do_scaling=False` lets the Remote VAE work as a drop-in replacement for `pipe.vae.decode`. Scaling must then be applied to the input before calling `remote_decode` (see the sketch after this list).
- With `do_scaling=True`, scaling is applied by the Remote VAE.
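A minimal sketch of the drop-in case, assuming an SD1.x latent (the `0.18215` scaling factor is the one used in the examples above):

```python
# Apply the scaling factor locally, then have the endpoint skip it.
latent = latent / 0.18215  # `vae.config.scaling_factor` for SD1.x VAEs
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    do_scaling=False,
)
```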
#### `output_type`

Image models support: `pil`, `pt`.

Video models support: `mp4`, `pil`, `pt`.

- `output_type="pil"` returns an image encoded as `image_format` for image models. For video models, it returns a tensor (equivalent to `postprocess_video(frames, output_type="pt")`) with final postprocessing applied, from which the frame images are created locally.
- `output_type="pt"` with `partial_postprocess=False` returns a `torch.Tensor` from before `postprocess`; the final postprocessing and image creation are done locally.
- `output_type="pt"` with `partial_postprocess=True` returns a `torch.Tensor` with `postprocess` applied; only the final image creation (`PIL.Image.fromarray`) is done locally. This reduces transfer size compared to `partial_postprocess=False` (see the sketch after this list).
- `output_type="mp4"` applies `postprocess_video(frames, output_type="pil")` followed by `export_to_video`, and returns the `bytes` of the `mp4`.
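For example, the lowest-transfer tensor path, sketched with the inline helper and a random SD-shaped latent:

```python
# The endpoint applies `postprocess` and returns a uint8 image tensor;
# only `PIL.Image.fromarray` runs locally.
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    output_type="pt",
    partial_postprocess=True,
)
```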
#### `input_tensor_type`/`output_tensor_type`

Choices: `base64`, `binary`.

Using `binary` reduces transfer size, as shown in the sketch below.
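A sketch combining both binary directions with the tensor output path above (same helper and endpoint assumptions):

```python
# Raw bytes in both directions instead of base64-encoded JSON payloads.
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    input_tensor_type="binary",
    output_tensor_type="binary",
    output_type="pt",
    partial_postprocess=True,
)
```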
#### `image_format`

Choices: `jpg`, `png`.

`jpg` transfers faster but at lower quality.
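For instance, requesting a lossless PNG instead of the default JPEG (same helper and endpoint as above):

```python
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    output_type="pil",
    image_format="png",
)
```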
#### `height`/`width`

Required for packed latents in Flux. Not required with `do_scaling=False`, since `unpack` occurs before scaling.
### Generation
@@ -337,6 +151,7 @@ latent = pipe(
```python
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.18215,
)
image.save("test.jpg")
```

> Review comment on `scaling_factor=0.18215`: We should discuss how users can get this value I think. Something like:
>
> ```python
> from huggingface_hub import hf_hub_download
> file_path = hf_hub_download(repo_id, filename="vae/config.json")
> ...
> ```
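One way to flesh out the reviewer's sketch; the `repo_id` is a placeholder, and it assumes the checkpoint stores a `vae/config.json` the way Stable Diffusion repos do:

```python
import json

from huggingface_hub import hf_hub_download

repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # placeholder repo
file_path = hf_hub_download(repo_id, filename="vae/config.json")
with open(file_path) as f:
    vae_config = json.load(f)

scaling_factor = vae_config["scaling_factor"]  # 0.18215 for SD1.x VAEs
```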
@@ -375,6 +190,8 @@ image = remote_decode(

```python
    tensor=latent,
    height=1024,
    width=1024,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
image.save("test.jpg")
```
@@ -456,6 +273,7 @@ def decode_worker(q: queue.Queue):

```python
        image = remote_decode(
            endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
            tensor=item,
            scaling_factor=0.18215,
        )
        display(image)
        q.task_done()
```
Review discussion:

> Confusing for the average reader imo. Developers can check the docstrings.

> I think it's better to keep it. We can add a note "Users are not required to go through this" or something like that.

> It was also outdated, will add it back based on current docstrings.