Skip to content

Commit feead18

Browse files
authored
Return None as dynamic shape when enable_dynamic_shape is False
Differential Revision: D72805966 Pull Request resolved: #10073
1 parent 19c12c6 commit feead18

File tree

5 files changed

+143
-7
lines changed

5 files changed

+143
-7
lines changed

extension/llm/export/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ runtime.python_library(
2222
"//bento/...",
2323
"//bento_kernels/...",
2424
"//executorch/examples/...",
25+
"//executorch/extension/llm/...",
2526
"//meta_intern_odllm/...",
2627
],
2728
deps = [

extension/llm/export/builder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def _get_dynamic_shape(self) -> Any:
178178
return self.dynamic_shapes
179179

180180
dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
181-
182-
if not self.use_kv_cache:
183-
# Only one input argument: tokens
184-
self.dynamic_shapes = ({1: dim},)
185-
elif self.enable_dynamic_shape:
186-
# Two input arguments: tokens and input_pos but input_pos is static shape
187-
self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
181+
if self.enable_dynamic_shape:
182+
if not self.use_kv_cache:
183+
# Only one input argument: tokens
184+
self.dynamic_shapes = ({1: dim},)
185+
else:
186+
# Two input arguments: tokens and input_pos but input_pos is static shape
187+
self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
188188
else:
189189
# Two input arguments: tokens and input_pos but both are of static shape
190190
self.dynamic_shapes = None

extension/llm/export/test/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
8+
9+
oncall("executorch")
10+
11+
# Unit tests for the LLM export builder (covers LLMEdgeManager._get_dynamic_shape).
runtime.python_test(
    name = "test_builder",
    srcs = ["test_builder.py"],
    deps = [
        "//caffe2:torch",
        "//executorch/extension/llm/export:export_lib",
    ],
)

extension/llm/export/test/__init__.py

Whitespace-only changes.
extension/llm/export/test/test_builder.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
import unittest
9+
from unittest.mock import MagicMock
10+
11+
import torch
12+
13+
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
14+
15+
16+
class TestLLMEdgeManager(unittest.TestCase):
    """Unit tests for LLMEdgeManager._get_dynamic_shape().

    Exercises the four branches: preset dynamic_shapes, dynamic shape
    enabled with and without a KV cache, and dynamic shape disabled.
    """

    def setUp(self) -> None:
        # The eager model is never invoked by _get_dynamic_shape, so a
        # MagicMock stand-in is sufficient.
        self.mock_model = MagicMock()
        self.modelname = "test_model"
        self.max_seq_len = 2048
        self.dtype = DType.fp32
        self.example_inputs = (torch.zeros((1, 10), dtype=torch.long),)
        self.example_kwarg_inputs = {"input_pos": torch.tensor([0])}

    def _build_manager(self, **kwargs) -> LLMEdgeManager:
        """Construct an LLMEdgeManager from the shared fixtures plus overrides."""
        return LLMEdgeManager(
            model=self.mock_model,
            modelname=self.modelname,
            max_seq_len=self.max_seq_len,
            dtype=self.dtype,
            example_inputs=self.example_inputs,
            **kwargs,
        )

    def test_get_dynamic_shape_with_preset_dynamic_shapes(self) -> None:
        """A preset dynamic_shapes value is returned untouched."""
        preset_dynamic_shapes = {"preset": "shapes"}
        manager = self._build_manager(
            use_kv_cache=False,
            dynamic_shapes=preset_dynamic_shapes,
        )
        self.assertEqual(manager._get_dynamic_shape(), preset_dynamic_shapes)

    def test_get_dynamic_shape_with_dynamic_shape_enabled_no_kv_cache(self) -> None:
        """enable_dynamic_shape=True, use_kv_cache=False: one dynamic tokens arg."""
        manager = self._build_manager(use_kv_cache=False, enable_dynamic_shape=True)
        result = manager._get_dynamic_shape()

        # Single input argument (tokens), dynamic along dim 1.
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 1)
        self.assertIsInstance(result[0], dict)
        self.assertIn(1, result[0])
        # The token dim ranges up to max_seq_len - 1.
        self.assertEqual(result[0][1].max, self.max_seq_len - 1)

    def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> None:
        """enable_dynamic_shape=True, use_kv_cache=True: tokens dynamic, input_pos static."""
        manager = self._build_manager(use_kv_cache=True, enable_dynamic_shape=True)
        result = manager._get_dynamic_shape()

        # Two input arguments: tokens and input_pos.
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)

        # First element: tokens, dynamic along dim 1 up to max_seq_len - 1.
        self.assertIsInstance(result[0], dict)
        self.assertIn(1, result[0])
        self.assertEqual(result[0][1].max, self.max_seq_len - 1)

        # Second element: input_pos with a static shape of {0: 1}.
        self.assertIsInstance(result[1], dict)
        self.assertIn("input_pos", result[1])
        self.assertIsInstance(result[1]["input_pos"], dict)
        self.assertIn(0, result[1]["input_pos"])
        self.assertEqual(result[1]["input_pos"][0], 1)

    def test_get_dynamic_shape_with_dynamic_shape_disabled(self) -> None:
        """enable_dynamic_shape=False yields None regardless of KV-cache use."""
        manager = self._build_manager(
            use_kv_cache=True,  # Doesn't matter for this test
            enable_dynamic_shape=False,
        )
        self.assertIsNone(manager._get_dynamic_shape())

0 commit comments

Comments
 (0)