Skip to content

Enforce that the input and output name are not optional and remove the logic to deal with non input name and output name logic. #1889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@ message TfConcreteFunctionHandle {
// SavedModel) but we want to allow calling the function with positional
// arguments.
//
// If this field is not set, this function can only be called with
// keyword-only arguments.
optional StrList input_names = 20;
repeated string input_names = 20;

// Similar to `input_names`, but for outputs.
optional StrList output_names = 30;
}

message StrList {
repeated string elements = 1;
repeated string output_names = 30;
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _is_dict_only(tree: TfSignature) -> bool:
return False
# LINT.ThenChange(//depot//learning/infra/mira/experimental/orbax_model/tensorflow/tf_compatible_optional_function_handler.cc)

_NamesAndSequence = Tuple[Sequence[str] | None, Sequence[Any] | None]
_NamesAndSequence = Tuple[Sequence[str], Sequence[Any]]


# We choose to rely solely on a concrete function's TF signature to
Expand All @@ -83,61 +83,25 @@ def _is_dict_only(tree: TfSignature) -> bool:
# TF signature a dict instead if they want to serve the function on
# Servomatic. If we find that there are too many users relying on this
# SavedModel behavior, we can revisit the decision here.
def _generate_names(
tree: TfSignature, prefix: str = '', *, fixed_name_pattern: bool
) -> _NamesAndSequence:
if not fixed_name_pattern and _is_dict_only(tree):
# If the input signature is dict-only, the function shouldn't be
# called with positional arguments anyway, so we don't generate
# names and just return None.
return None, None
def _generate_names(tree: TfSignature, prefix: str = '') -> _NamesAndSequence:
flat = obm.tree_util.flatten(tree)
return tuple(f'{prefix}_{i}' for i in range(len(flat))), flat


def _get_input_names(
tree: TfSignature, *, fixed_name_pattern: bool
) -> _NamesAndSequence:
return _generate_names(
tree, prefix='input', fixed_name_pattern=fixed_name_pattern
)


def _get_output_names(
tree: TfSignature, *, fixed_name_pattern: bool
) -> _NamesAndSequence:
return _generate_names(
tree, prefix='output', fixed_name_pattern=fixed_name_pattern
)


_T0 = TypeVar('_T0')
_T1 = TypeVar('_T1')

def _get_input_names(tree: TfSignature) -> _NamesAndSequence:
return _generate_names(tree, prefix='input')

def optional_map(f: Callable[[_T0], _T1], a: _T0 | None) -> _T1 | None:
if a is None:
return None
return f(a)


def _to_optional_str_list(
seq: Sequence[str] | None,
) -> tf_concrete_function_handle_pb2.StrList | None:
return optional_map(
lambda s: tf_concrete_function_handle_pb2.StrList(elements=s), seq
)
def _get_output_names(tree: TfSignature) -> _NamesAndSequence:
return _generate_names(tree, prefix='output')


# TODO(b/400777413): Remove `fixed_name_pattern` in tf2obm once GemaxProd no
# longer uses it.
def tf_concrete_function_name_to_obm_function(
name: str,
*,
input_signature: TfSignature | None = None,
output_signature: TfSignature | None = None,
fn: tf.types.experimental.ConcreteFunction | None = None,
fixed_name_pattern: bool = False,
) -> obm.SerializableFunction:
"""Converts a TensorFlow (TF) concrete function name (with input/output signatures) to an Orbax Model (OBM) function.

Expand All @@ -153,10 +117,6 @@ def tf_concrete_function_name_to_obm_function(
input_signature: the input signature of the concrete function.
output_signature: the output signature of the concrete function.
fn: the concrete function itself.
fixed_name_pattern: see `to_keyword_only_fn`. If this function is used with
`to_keyword_only_fn`, their `fixed_name_pattern` arguments must match. If
it is used with `save_tf_concrete_functions`, `fixed_name_pattern` should
be set to `False`.

Returns:
An OBM function.
Expand All @@ -175,17 +135,13 @@ def tf_concrete_function_name_to_obm_function(
input_signature = get_input_signature(fn)
output_signature = get_output_signature(fn)

input_names, _ = _get_input_names(
input_signature, fixed_name_pattern=fixed_name_pattern
)
output_names, _ = _get_output_names(
output_signature, fixed_name_pattern=fixed_name_pattern
)
input_names, _ = _get_input_names(input_signature)
output_names, _ = _get_output_names(output_signature)
unstructured_data = obm.manifest_pb2.UnstructuredData(
inlined_bytes=tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(
fn_name=name,
input_names=_to_optional_str_list(input_names),
output_names=_to_optional_str_list(output_names),
input_names=list(input_names),
output_names=list(output_names),
).SerializeToString(),
mime_type=TF_CONCRETE_FUNCTION_HANDLE_MIME_TYPE,
version=TF_CONCRETE_FUNCTION_HANDLE_VERSION,
Expand Down Expand Up @@ -308,16 +264,16 @@ def _to_args_kwargs_pattern(
)


T0 = TypeVar('T0')
T1 = TypeVar('T1')
T2 = TypeVar('T2')


def unzip2(
xys: Iterable[tuple[T1, T2]],
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
xys: Iterable[tuple[T0, T1]],
) -> tuple[tuple[T0, ...], tuple[T1, ...]]:
"""Unzip sequence of length-2 tuples into two tuples."""
xs: list[T1] = []
ys: list[T2] = []
xs: list[T0] = []
ys: list[T1] = []
for x, y in xys:
xs.append(x)
ys.append(y)
Expand All @@ -328,15 +284,11 @@ def unzip2(
# longer uses it.
def to_keyword_only_fn(
f: tf.types.experimental.ConcreteFunction,
*,
fixed_name_pattern: bool = False,
) -> tf.types.experimental.ConcreteFunction:
"""Wraps a function into one whose inputs and outputs are keyword-only.

Args:
f: a TF concrete function.
fixed_name_pattern: if True, the new function's input (output) names will be
in the form "input_0", "input_1", ... (output_0", "output_1", ...).

Returns:
The wrapped function (also a TF concrete function).
Expand All @@ -345,7 +297,7 @@ def to_keyword_only_fn(
output_signature = get_output_signature(f)

def input_names_fn(tree: TfSignature) -> _NamesAndSequence:
names, flat = _get_input_names(tree, fixed_name_pattern=fixed_name_pattern)
names, flat = _get_input_names(tree)
if names is None and is_args_kwargs_pattern(tree):
args, kwargs = tree
if not kwargs and len(args) == 1 and _is_str_tensor_spec_dict(args[0]):
Expand All @@ -365,9 +317,7 @@ def input_names_fn(tree: TfSignature) -> _NamesAndSequence:
new_input_signature, input_names = _make_dict_only_signature(
input_signature, input_names_fn
)
output_names, _ = _get_output_names(
output_signature, fixed_name_pattern=fixed_name_pattern
)
output_names, _ = _get_output_names(output_signature)

if input_names is None and output_names is None:
return f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,21 @@ def test_is_dict_only(self, signature, expected, as_output_signature):
(
"tuple",
_TUPLE,
("_0", "_1", "_2"),
("my_prefix_0", "my_prefix_1", "my_prefix_2"),
),
(
"dict",
_DICT,
None,
("my_prefix_0", "my_prefix_1"),
),
)
)
def test_generate_names(self, signature, expected_names, as_output_signature):
if as_output_signature:
signature = _as_output_signature(signature)
prefix = "my_prefix"
if expected_names is not None:
expected_names = tuple(prefix + name for name in expected_names)
names, _ = _generate_names(
signature, prefix=prefix, fixed_name_pattern=False
)

names, _ = _generate_names(signature, prefix=prefix)
self.assertEqual(
names,
expected_names,
Expand All @@ -154,7 +151,7 @@ def test_generate_names(self, signature, expected_names, as_output_signature):
),
(
((), _DICT),
None,
((), _dict_from_seq("input_", obm.tree_util.flatten(_DICT))),
),
(
(_TUPLE, _DICT),
Expand All @@ -167,7 +164,7 @@ def test_generate_names(self, signature, expected_names, as_output_signature):
),
(
((_DICT,), {}),
((), _DICT),
((), _dict_from_seq("input_", obm.tree_util.flatten(_DICT))),
),
))
for output_case_id, (output_sig, expected_output_sig) in enumerate((
Expand All @@ -181,7 +178,7 @@ def test_generate_names(self, signature, expected_names, as_output_signature):
),
(
_DICT,
None,
(_dict_from_seq("output_", obm.tree_util.flatten(_DICT))),
),
(
(_DICT,),
Expand All @@ -196,10 +193,6 @@ def test_generate_names(self, signature, expected_names, as_output_signature):
def test_to_keyword_only_fn(
self, input_sig, expected_input_sig, output_sig, expected_output_sig
):
if expected_input_sig is None:
expected_input_sig = input_sig
if expected_output_sig is None:
expected_output_sig = output_sig

@tf.function(autograph=False)
def f(*args, **kwargs):
Expand Down Expand Up @@ -503,19 +496,11 @@ def tf_fn(a):
)
with open(os.path.join(save_dir_path, pre_processor_filename), "rb") as f:
pre_processor_proto.ParseFromString(f.read())
expected_pre_processor_proto_text = (
expected_pre_processor_proto_text = f"""
fn_name: "{pre_processor_name_in_tf}"
input_names: "input_0"
output_names: "output_0"
"""
fn_name: \""""
+ pre_processor_name_in_tf
+ """\"
input_names {
elements: "input_0"
}
output_names {
elements: "output_0"
}
"""
)
expected_pre_processor_proto = text_format.Parse(
expected_pre_processor_proto_text,
tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(),
Expand All @@ -531,19 +516,11 @@ def tf_fn(a):
)
with open(os.path.join(save_dir_path, post_processor_filename), "rb") as f:
post_processor_proto.ParseFromString(f.read())
expected_post_processor_proto_text = (
expected_post_processor_proto_text = f"""
fn_name: "{post_processor_name_in_tf}"
input_names: "input_0"
output_names: "output_0"
"""
fn_name: \""""
+ post_processor_name_in_tf
+ """\"
input_names {
elements: "input_0"
}
output_names {
elements: "output_0"
}
"""
)
expected_post_processor_proto = text_format.Parse(
expected_post_processor_proto_text,
tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(),
Expand Down
Loading