Skip to content

[Hackability Refactor] Collapse export_util into export.py #1057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 260 additions & 8 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,15 @@

from torch.export import Dim

try:
executorch_export_available = True
from export_util.export_et import export_model as export_model_et
except Exception as e:
executorch_exception = f"ET EXPORT EXCEPTION: {e}"
executorch_export_available = False


default_device = "cpu"


"""
Export for Server
"""


def export_for_server(
model: nn.Module,
device: Optional[str] = "cpu",
Expand Down Expand Up @@ -79,6 +77,260 @@ def export_for_server(
return so


"""
Export for ExecuTorch

TODO (https://github.com/pytorch/torchchat/issues/1058): Replace
replace_attention_with_custom_sdpa_attention with ET's implementation
"""

try:
executorch_export_available = True

import logging

from typing import Any, Dict, Tuple, Union

import executorch.exir as exir

from build.model import apply_rotary_emb, Attention
from build.utils import get_precision

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackDynamicallyQuantizedPartitioner,
)
from executorch.exir import EdgeProgramManager, to_edge

from executorch.exir.capture._config import (
EdgeCompileConfig,
ExecutorchBackendConfig,
)
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import (
ConstraintBasedSymShapeEvalPass,
)
from executorch.exir.tracer import Value

from torch._export import capture_pre_autograd_graph
from torch.export import export, ExportedProgram

default_device = "cpu"

_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
_check_ir_validity=True,
)

class CustomKVCache(nn.Module):
    """Key/value cache stored as ``(batch, seq, heads, head_dim)`` float32 buffers.

    The layout is flipped relative to build.model's KVCache, which keeps the
    heads axis before the sequence axis.

    Args:
        max_batch_size: Size of the batch axis of both buffers.
        max_seq_length: Size of the sequence axis of both buffers.
        n_heads: Size of the heads axis of both buffers.
        head_dim: Size of the per-head feature axis of both buffers.
        dtype: Overwritten with ``torch.float`` below, so the value passed in
            has no effect on the buffers.
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
        super().__init__()

        # The buffers are always float32, whatever dtype the caller passed.
        dtype = torch.float

        # This is flipped around from what is in build.model's KVCache
        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(self, input_pos, k_val, v_val):
        """Write ``k_val``/``v_val`` into the caches at positions ``input_pos``.

        Args:
            input_pos: Index (or index tensor) into the sequence axis (dim 1).
            k_val: New keys; converted to float32 before being stored.
                Assumed to be ``(batch, len(input_pos), heads, head_dim)`` so
                that it matches the cache layout.
            v_val: New values, same handling as ``k_val``.

        Returns:
            The ``(k_cache, v_cache)`` buffers after the in-place write.
        """
        k_out = self.k_cache
        v_out = self.v_cache
        # The sequence axis is dim 1 in this layout, so positions are selected
        # there. Indexing dim 2 would address the heads axis instead.
        k_out[:, input_pos] = k_val.float()
        v_out[:, input_pos] = v_val.float()

        return k_out, v_out

class CustomSDPAAttention(nn.Module):
    """Attention module that routes through ``torch.ops.llama.sdpa_with_kv_cache``.

    Built from an existing ``Attention`` instance: the projection layers are
    reused as-is, and a new ``CustomKVCache`` is created whose dimensions are
    read from the source module's ``kv_cache.k_cache`` shape.
    """

    def __init__(self, attention: Attention):
        super().__init__()

        # Reuse the source module's projection layers (same objects, no copy).
        self.wq = attention.wq
        self.wk = attention.wk
        self.wv = attention.wv
        self.wo = attention.wo

        # The source cache shape is unpacked as (batch, heads, seq, head_dim);
        # CustomKVCache takes the sequence length before the head count.
        source_k_cache = attention.kv_cache.k_cache
        max_batch_size, n_heads, max_seq_length, head_dim = source_k_cache.shape
        self.kv_cache = CustomKVCache(
            max_batch_size,
            max_seq_length,
            n_heads,
            head_dim,
            source_k_cache.dtype,
        )

        self.n_heads = attention.n_heads
        self.head_dim = attention.head_dim
        self.n_local_heads = attention.n_local_heads
        self.dim = attention.dim

    def forward(self, x, freqs_cis, mask, input_pos=None):
        """Project ``x`` to q/k/v, apply rotary embeddings and run the custom op.

        ``mask`` is accepted but not used in this method. ``input_pos`` must
        be indexable, because its last element is passed to the custom op.
        """
        bsz, seqlen, _ = x.shape

        # Projections, reshaped to (batch, seq, heads, head_dim).
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        k = self.wk(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = self.wv(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)

        # All three tensors are handed to the custom op as float32.
        q = apply_rotary_emb(q, freqs_cis).to(dtype=torch.float)
        k = apply_rotary_emb(k, freqs_cis).to(dtype=torch.float)
        v = v.to(dtype=torch.float)

        # KV cache should always be enabled
        assert self.kv_cache is not None
        start_pos = input_pos[-1].item()
        attn_out = torch.ops.llama.sdpa_with_kv_cache(
            q,
            k,
            v,
            self.kv_cache.k_cache,
            self.kv_cache.v_cache,
            start_pos,
            seqlen,
        )
        attn_out = attn_out.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
        return self.wo(attn_out)

def replace_attention_with_custom_sdpa_attention(module: nn.Module):
    """Recursively swap every ``Attention`` child for a ``CustomSDPAAttention``.

    The replacement happens in place on ``module``. Children that are not
    ``Attention`` instances are descended into; replaced children are not.
    """
    # The imported name is unused here. NOTE(review): presumably this import
    # registers torch.ops.llama.sdpa_with_kv_cache as a side effect; confirm.
    from executorch.examples.models.llama2.custom_ops import (  # noqa
        sdpa_with_kv_cache,
    )

    for child_name, submodule in module.named_children():
        if not isinstance(submodule, Attention):
            replace_attention_with_custom_sdpa_attention(submodule)
            continue
        setattr(module, child_name, CustomSDPAAttention(submodule))

def _to_core_aten(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    verbose=True,
) -> ExportedProgram:
    """Run ``torch.export.export`` on ``model`` and return the exported program.

    Args:
        model: An ``fx.GraphModule`` or ``nn.Module`` to export.
        example_inputs: Positional example inputs forwarded to ``export``.
        dynamic_shapes: Forwarded to ``export`` as ``dynamic_shapes``.
        verbose: When true, log the resulting graph at INFO level.

    Returns:
        The ``ExportedProgram`` produced by ``export``.

    Raises:
        ValueError: If ``model`` is neither an ``fx.GraphModule`` nor an
            ``nn.Module``.
    """
    # post autograd export. eventually this will become .to_core_aten
    if not isinstance(model, (torch.fx.GraphModule, torch.nn.Module)):
        raise ValueError(
            "Expected passed in model to be an instance of fx.GraphModule "
            f"or nn.Module, got {type(model)}"
        )
    core_aten_ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes)
    if verbose:
        # Lazy %-formatting: the graph is only stringified when INFO is enabled.
        logging.info("Core ATen graph:\n%s", core_aten_ep.graph)
    return core_aten_ep

def _core_aten_to_edge(
    core_aten_exir_ep: ExportedProgram,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=None,
    verbose=True,
) -> EdgeProgramManager:
    """Lower an exported program to the Edge dialect via ``to_edge``.

    When no (or a falsy) ``edge_compile_config`` is given, a config with IR
    validity checking turned off is used. If ``verbose`` is true, the
    resulting graph is logged at INFO level.
    """
    # quant ops currently break ir verification, hence the fallback config
    compile_config = edge_compile_config or exir.EdgeCompileConfig(
        _check_ir_validity=False,
    )
    manager: EdgeProgramManager = to_edge(
        core_aten_exir_ep,
        constant_methods=edge_constant_methods,
        compile_config=compile_config,
    )
    if verbose:
        logging.info(f"Exported graph:\n{manager.exported_program().graph}")
    return manager

def export_to_edge(
    model: Union[torch.fx.GraphModule, torch.nn.Module],
    example_inputs: Tuple[Value, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    edge_constant_methods: Optional[Dict[str, Any]] = None,
    edge_compile_config=_EDGE_COMPILE_CONFIG,
    verbose=True,
) -> EdgeProgramManager:
    """Export ``model`` and lower the result to an ``EdgeProgramManager``.

    Two steps: ``_to_core_aten`` produces the exported program, then
    ``_core_aten_to_edge`` lowers it. ``verbose`` is passed to both.
    """
    exported_program = _to_core_aten(
        model,
        example_inputs,
        dynamic_shapes,
        verbose=verbose,
    )
    edge_program_manager = _core_aten_to_edge(
        exported_program,
        edge_constant_methods,
        edge_compile_config,
        verbose=verbose,
    )
    return edge_program_manager

def export_for_et(model, device, output_path, args=None) -> str:  # noqa: C901
    """Export ``model`` as an ExecuTorch program written to ``output_path``.

    Steps: cast the model to the precision reported by ``get_precision()``,
    replace its ``Attention`` modules with the custom-SDPA variant, capture
    and lower the graph, delegate to XNNPACK, and serialize the result.

    Args:
        model: Module to export. Its attention submodules are replaced in
            place, and ``model.to`` may be called on it.
        device: Device on which the example input tensors are created.
        output_path: File path the program is written to (opened as ``"wb"``).
        args: Not used by this function.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If the target precision is not float16, bfloat16 or
            float32.
    """
    # Example inputs used for capture and export: a (1, 1) long tensor and a
    # (1,) long tensor. Named so the builtin ``input`` is not shadowed.
    example_inputs = (
        torch.tensor([[1]], dtype=torch.long, device=device),
        torch.tensor([0], dtype=torch.long, device=device),
    )

    state_dict = model.state_dict()
    # The dtype of the first state-dict entry stands in for the whole model.
    state_dict_dtype = state_dict[next(iter(state_dict))].dtype
    target_precision = get_precision()
    dynamic_shapes = None

    # TODO: need to use kv sdpa?
    edge_config = EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_type_promotion=bool(target_precision == torch.float16),
    )

    if target_precision in (torch.float16, torch.bfloat16):
        # A bfloat16 target is exported as float16. The original code had a
        # later bfloat16 branch that could never run because this one matches
        # first. NOTE(review): confirm that float16 is the intended fallback.
        if state_dict_dtype != torch.float16:
            print("model.to torch.float16")
            model = model.to(dtype=torch.float16)
    elif target_precision == torch.float32:
        if state_dict_dtype != torch.float32:
            print("model.to torch.float32")
            model = model.to(dtype=torch.float32)
    else:
        raise ValueError(f"Unsupported dtype for ET export: {target_precision}")

    replace_attention_with_custom_sdpa_attention(model)
    with torch.nn.attention.sdpa_kernel(
        [torch.nn.attention.SDPBackend.MATH]
    ), torch.no_grad():
        m = capture_pre_autograd_graph(
            model, example_inputs, dynamic_shapes=dynamic_shapes
        )

        edge_manager = export_to_edge(
            m,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            edge_compile_config=edge_config,
        )
    edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
    export_program = edge_manager.to_executorch(
        ExecutorchBackendConfig(
            extract_constant_segment=True,
            extract_delegate_segments=True,
            passes=[
                QuantFusionPass(),
            ],
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        )
    )

    print("The methods are: ", export_program.methods)
    with open(output_path, "wb") as f:
        export_program.write_to_file(f)

    return output_path

except Exception as e:
executorch_exception = f"ET EXPORT EXCEPTION: {e}"
executorch_export_available = False


"""
Exporting Flow
"""


def main(args):
builder_args = BuilderArgs.from_args(args)
quantize = args.quantize
Expand Down Expand Up @@ -153,7 +405,7 @@ def main(args):
output_pte_path = str(os.path.abspath(output_pte_path))
if executorch_export_available:
print(f"Exporting model using ExecuTorch to {output_pte_path}")
export_model_et(
export_for_et(
model_to_pte, builder_args.device, args.output_pte_path, args
)
else:
Expand Down
Loading
Loading