
Update Remote VAE blog #2714


Merged: 2 commits, Mar 2, 2025
remote_vae.md: 204 changes (11 additions, 193 deletions)
Therefore, we want to pilot an idea with the community — delegating the decoding to a remote endpoint.

No data is stored or tracked, and code is open source. We made some changes to [huggingface-inference-toolkit](https://github.com/hlky/huggingface-inference-toolkit/tree/fix-text-support-binary) and use [custom handlers](https://huggingface.co/hlky/sd-vae-ft-mse/blob/main/handler.py).

This experimental feature is developed by [Diffusers 🧨](https://huggingface.co/docs/diffusers/hybrid_inference/overview).

**Table of contents**:

- [Getting started](#getting-started)
Below, we cover three use cases where we think this remote VAE inference would be beneficial.
First, we have created a helper method for interacting with Remote VAEs.

> [!NOTE]
> Install `diffusers` from `main` to run the code.
> `pip install git+https://github.com/huggingface/diffusers@main`

<details><summary>Code</summary>
<p>

```python
from typing import cast, List, Literal, Optional, Union

import base64
import io
import json
import requests
import torch
from PIL import Image

from diffusers.image_processor import VaeImageProcessor
from diffusers.video_processor import VideoProcessor
from safetensors.torch import _tobytes

# Maps the dtype string sent by the endpoint back to a torch dtype.
DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "uint8": torch.uint8,
}


def remote_decode(
    endpoint: str,
    tensor: torch.Tensor,
    processor: Optional[Union[VaeImageProcessor, VideoProcessor]] = None,
    do_scaling: bool = True,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["base64", "binary"] = "base64",
    output_tensor_type: Literal["base64", "binary"] = "base64",
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, torch.Tensor]:
    if tensor.ndim == 3 and height is None and width is None:
        raise ValueError("`height` and `width` required for packed latents.")
    if output_type == "pt" and partial_postprocess is False and processor is None:
        raise ValueError(
            "`processor` is required with `output_type='pt'` and `partial_postprocess=False`."
        )
    headers = {}
    parameters = {
        "do_scaling": do_scaling,
        "output_type": output_type,
        "partial_postprocess": partial_postprocess,
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype).split(".")[-1],
    }
    if height is not None and width is not None:
        parameters["height"] = height
        parameters["width"] = width
    # Serialize the latent and pick the request/response content types.
    tensor_data = _tobytes(tensor, "tensor")
    if input_tensor_type == "base64":
        headers["Content-Type"] = "tensor/base64"
    elif input_tensor_type == "binary":
        headers["Content-Type"] = "tensor/binary"
    if output_type == "pil" and image_format == "jpg" and processor is None:
        headers["Accept"] = "image/jpeg"
    elif output_type == "pil" and image_format == "png" and processor is None:
        headers["Accept"] = "image/png"
    elif (output_tensor_type == "base64" and output_type == "pt") or (
        output_tensor_type == "base64"
        and output_type == "pil"
        and processor is not None
    ):
        headers["Accept"] = "tensor/base64"
    elif (output_tensor_type == "binary" and output_type == "pt") or (
        output_tensor_type == "binary"
        and output_type == "pil"
        and processor is not None
    ):
        headers["Accept"] = "tensor/binary"
    elif output_type == "mp4":
        headers["Accept"] = "text/plain"
    if input_tensor_type == "base64":
        kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
    elif input_tensor_type == "binary":
        kwargs = {"data": tensor_data}
    response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
    if not response.ok:
        raise RuntimeError(response.json())
    # Rebuild a torch.Tensor from the response when one is expected.
    if output_type == "pt" or (output_type == "pil" and processor is not None):
        if output_tensor_type == "base64":
            content = response.json()
            output_tensor = base64.b64decode(content["inputs"])
            parameters = content["parameters"]
            shape = parameters["shape"]
            dtype = parameters["dtype"]
        elif output_tensor_type == "binary":
            output_tensor = response.content
            parameters = response.headers
            shape = json.loads(parameters["shape"])
            dtype = parameters["dtype"]
        torch_dtype = DTYPE_MAP[dtype]
        output_tensor = torch.frombuffer(
            bytearray(output_tensor), dtype=torch_dtype
        ).reshape(shape)
    if output_type == "pt":
        if partial_postprocess:
            # The endpoint already postprocessed; only image creation happens here.
            output = [Image.fromarray(image.numpy()) for image in output_tensor]
            if len(output) == 1:
                output = output[0]
        else:
            if processor is None:
                output = output_tensor
            else:
                if isinstance(processor, VideoProcessor):
                    output = cast(
                        List[Image.Image],
                        processor.postprocess_video(output_tensor, output_type="pil")[0],
                    )
                else:
                    output = cast(
                        Image.Image,
                        processor.postprocess(output_tensor, output_type="pil")[0],
                    )
    elif output_type == "pil" and processor is None:
        # The endpoint returned an encoded image directly.
        output = Image.open(io.BytesIO(response.content)).convert("RGB")
    elif output_type == "pil" and processor is not None:
        output = [
            Image.fromarray(image)
            for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255)
            .round()
            .astype("uint8")
        ]
    elif output_type == "mp4":
        output = response.content
    return output

# This helper now ships in diffusers and can be imported directly:
from diffusers.utils.remote_utils import remote_decode
```

</p>
</details>

Here, we show how to use the remote VAE on random tensors.

```python
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    scaling_factor=0.18215,
)
```

```python
image = remote_decode(
    endpoint=...,  # Flux VAE endpoint URL (not shown in this excerpt)
    tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
    height=1024,
    width=1024,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
```


### Options

> **Review (Contributor Author):** Confusing for the average reader imo. Developers can check the docstrings.
>
> **Review (Member):** I think it's better to keep it. We can add a note "Users are not required to go through this" or something like that.
>
> **Review (Contributor Author):** It was also outdated, will add it back based on current docstrings.

Let's review the available options.

```python
def remote_decode(
    endpoint: str,
    tensor: torch.Tensor,
    processor: Optional[Union[VaeImageProcessor, VideoProcessor]] = None,
    do_scaling: bool = True,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["base64", "binary"] = "base64",
    output_tensor_type: Literal["base64", "binary"] = "base64",
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, torch.Tensor]:
```

#### Overview of decoding

Decoding in a pipeline has three stages: `scaling` -> `decode` -> `postprocess`.

The options let the remote VAE slot into, or take over, any of these stages.
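
For intuition, this is roughly what those stages look like when run locally. This is a simplified sketch for an SD-style image pipeline, not the exact pipeline code; `vae` and `image_processor` stand for the pipeline's components:

```python
# scaling: undo the factor the VAE was trained with
latents = latents / vae.config.scaling_factor
# decode: latent -> image tensor
image = vae.decode(latents, return_dict=False)[0]
# postprocess: denormalize and convert to PIL
image = image_processor.postprocess(image, output_type="pil")[0]
```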

#### `processor`

With `output_type="pt"`, the endpoint returns a `torch.Tensor` from before `postprocess`. The final postprocessing and image creation are done locally.

With `output_type="pil"` on video models, `processor=VideoProcessor()` is required for some local postprocessing, as in the sketch below.
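
A hedged sketch of the video case; `video_vae_endpoint` and `video_latent` are placeholders, and the endpoint must serve a video VAE:

```python
from diffusers.video_processor import VideoProcessor

frames = remote_decode(
    endpoint=video_vae_endpoint,  # placeholder: a remote video VAE endpoint
    tensor=video_latent,          # placeholder: latent from a video pipeline
    output_type="pil",
    processor=VideoProcessor(),   # needed so the frames are assembled locally
)
```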

#### `do_scaling`

- `do_scaling=False` lets the remote VAE work as a drop-in replacement for `pipe.vae.decode`; scaling must be applied to the input before calling `remote_decode` (see the sketch below).
- With `do_scaling=True`, scaling is applied by the remote VAE.
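
A sketch of the drop-in case, assuming `pipe` is a loaded Stable Diffusion pipeline and `latents` came out of it:

```python
# Apply the VAE scaling locally, then send the pre-scaled latent.
latents = latents / pipe.vae.config.scaling_factor
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latents,
    do_scaling=False,
)
```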

#### `output_type`

Image models support: `pil`, `pt`.

Video models support: `mp4`, `pil`, `pt`.

- `output_type="pil"` returns an image (encoded as `image_format`) for image models. For video models, it returns a tensor equivalent to `postprocess_video(frames, output_type="pt")`, i.e. with final postprocessing applied; the frame images are then created locally.
- `output_type="pt"` with `partial_postprocess=False` returns a `torch.Tensor` from before `postprocess`. The final postprocessing and image creation are done locally.
- `output_type="pt"` with `partial_postprocess=True` returns a `torch.Tensor` with `postprocess` applied; only the final image creation (`PIL.Image.fromarray`) is done locally. This reduces transfer compared to `partial_postprocess=False` (see the sketch below).
- `output_type="mp4"` applies `postprocess_video(frames, output_type="pil")` followed by `export_to_video`, and returns the `bytes` of the `mp4`.
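
A sketch of the reduced-transfer variant, reusing the SD v1.5 endpoint from the examples above:

```python
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    output_type="pt",
    partial_postprocess=True,  # endpoint postprocesses; only `Image.fromarray` runs locally
)
```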

#### `input_tensor_type`/`output_tensor_type`

Choices: `base64` (default) and `binary`.

Using `binary` reduces transfer, since raw bytes avoid base64's roughly 33% size overhead; a sketch follows this paragraph.
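
For example (same endpoint as above; the two flags are independent, so either direction can be binary):

```python
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
    output_type="pt",
    partial_postprocess=True,
    input_tensor_type="binary",   # send raw bytes instead of base64
    output_tensor_type="binary",  # receive raw bytes instead of base64
)
```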

#### `image_format`

Choices: `jpg`, `png`.

`jpg` is faster but lower quality.

#### `height`/`width`

Required for packed latents in Flux. Not required with `do_scaling=False`, since `unpack` happens before scaling, so a pre-scaled input is already unpacked.


### Generation

```python
latent = pipe(
    ...  # arguments elided in this excerpt; the pipeline must return latents (e.g. output_type="latent")
)
image = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.18215,
)
image.save("test.jpg")
```

> **Review (Member):** We should discuss how users can get this value I think. Something like:
>
> ```python
> from huggingface_hub import hf_hub_download
>
> file_path = hf_hub_download(repo_id, filename="vae/config.json")
>
> ...
> ```
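
A hedged sketch of that suggestion (the repo id is an example and this is not code from the post): the factor can be read from the model's VAE config.

```python
import json

from huggingface_hub import hf_hub_download

config_path = hf_hub_download(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", filename="vae/config.json"
)
with open(config_path) as f:
    vae_config = json.load(f)

scaling_factor = vae_config["scaling_factor"]  # 0.18215 for this VAE
```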

```python
image = remote_decode(
    endpoint=...,  # Flux VAE endpoint URL (not shown in this excerpt)
    tensor=latent,
    height=1024,
    width=1024,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
image.save("test.jpg")
```

```python
def decode_worker(q: queue.Queue):
    ...  # queue consumption elided in this excerpt
    image = remote_decode(
        endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
        tensor=item,
        scaling_factor=0.18215,
    )
    display(image)
    q.task_done()
```
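
For context, a minimal sketch of how such a worker might be wired up (the queue setup and thread are assumptions, not code from the post):

```python
import queue
import threading

q = queue.Queue()
threading.Thread(target=decode_worker, args=(q,), daemon=True).start()

# The generation loop would enqueue latents as they become available:
# q.put(latent)

q.join()  # wait until every queued latent has been decoded
```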