Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions exir/_serialize/_named_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,32 @@
# pyre-strict

import hashlib

from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch

from executorch.exir._serialize.data_serializer import DataEntry
from executorch.exir.tensor_layout import TensorLayout


def _tensor_to_bytes(tensor: torch.Tensor) -> bytes:
"""Convert tensor to bytes using the fastest method available.

Uses numpy().tobytes() which is faster than bytes(untyped_storage())
for C-contiguous tensors. Falls back to untyped_storage() for
non-contiguous tensors (e.g., channels_last) to preserve memory layout.
"""
if not tensor.is_contiguous():
# For non-C-contiguous tensors (e.g., channels_last), use untyped_storage
# to preserve the actual memory layout
return bytes(tensor.untyped_storage())
if tensor.dtype == torch.bfloat16:
# BFloat16 is not supported by numpy, extract raw bytes via view
return tensor.view(torch.uint16).numpy().tobytes()
else:
return tensor.numpy().tobytes()


@dataclass
class NamedDataStoreOutput:
"""
Expand Down Expand Up @@ -169,7 +185,7 @@ def add_named_data(
f"Tensor {key} is a torch.Tensor, with tensor_layout {real_tensor_layout}. The provided tensor layout {tensor_layout} does not match."
)
tensor_layout = real_tensor_layout
byte_data = bytes(data.untyped_storage())
byte_data = _tensor_to_bytes(data)
else:
byte_data = data

Expand Down
1 change: 1 addition & 0 deletions exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
torch.bfloat16: ScalarType.BFLOAT16,
torch.quint4x2: ScalarType.QUINT4x2,
torch.uint16: ScalarType.UINT16,
torch.uint32: ScalarType.UINT32,
}


Expand Down
Loading