Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,9 @@ def replace(self, **changes):
def type_string(self):
value_str = f"{self.value}" if self.value is not None else "?"
type_str = "int" if self.python_type is int else "bool"
if not self.is_static_constrained():
# For non-static values, only show the type
return f"symbolic {type_str}"
return f"{type_str} {value_str}"

def __repr__(self):
Expand Down Expand Up @@ -1209,6 +1212,9 @@ def replace(self, **changes):

def type_string(self):
value_str = f"{self.value}" if self.value is not None else "?"
if not self.is_static_constrained():
# For non-static values, only show the type
return "symbolic float"
return f"float {value_str}"

def __repr__(self):
Expand Down
38 changes: 38 additions & 0 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,44 @@ def test_cache_symbolic_values_nn_parameter_static_shape():
assert isinstance(bsym.output.shape[1], thunder.core.proxies.IntegerProxy)


def test_cache_symbolic_values_int_float_inputs():
def foo(a, b):
return a + b

jfoo = thunder_jit(foo, cache="symbolic values")

a = 1
b = 2.0
actual = jfoo(a, b)
expected = foo(a, b)

assert_close(actual, expected)
assert thunder.cache_misses(jfoo) == 1
assert thunder.cache_hits(jfoo) == 0

a = 2
b = 3.0
actual = jfoo(a, b)
expected = foo(a, b)

assert_close(actual, expected)
assert thunder.cache_misses(jfoo) == 1
assert thunder.cache_hits(jfoo) == 1

trc = thunder.last_traces(jfoo)[-1]
for bsym in trc.bound_symbols:
if bsym.sym.name == prims.PrimIDs.UNPACK_TRIVIAL:
assert isinstance(bsym.output, (IntegerProxy, FloatProxy))

trc_str = str(trc)
# Verify that symbolic inputs are not baked in as constants in the trace string
assert 'a: "int 1"' not in trc_str
assert 'b: "float 2.0"' not in trc_str

assert 'a: "symbolic int"' in trc_str
assert 'b: "symbolic float"' in trc_str


def test_specific_dataclass_returns():
import transformers

Expand Down
Loading