Skip to content

Enforce that the input and output name are not optional and remove the logic to deal with non input name and output name logic. #1889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions model/orbax/experimental/model/core/python/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from types import NoneType # pylint: disable=g-importing-member
from typing import Any, Callable, Dict, Iterable, Iterator, List, Sequence, Tuple
from typing import TypeVar
from orbax.experimental.model.core.protos import type_pb2


# TODO(wangpeng): Use the `type Tree[T] = T | List[Tree[T]] | ...` syntax once
Expand Down Expand Up @@ -98,6 +99,35 @@ def flatten(tree: Tree[T4]) -> List[T4]:
return [tree]


def flatten_tree(tree: Tree[T4], key: str) -> List[Tuple[str, T4]]:
"""Flattens a tree to a list of (key, leaf) pairs.

Args:
tree: the tree to flatten.
key: the key of the current node.

Returns:
A list of (key, leaf) pairs.
"""
if isinstance(tree, (tuple, list)):
return flatten_lists(list(flatten_tree(x, key) for x in tree))
elif isinstance(tree, dict):
tree: Dict[str, Tree[T4]]
# Sort by key order (as opposed to insertion order like `OrderedDict`).
return flatten_lists(
list(flatten_tree(v, k) for k, v in sorted(tree.items()))
)
elif tree is None:
return []
# TODO: b/406087384 - Add support for StringPairs case.
elif isinstance(tree, type_pb2.StringTypePairs):
raise NotImplementedError(
"StringTypePairs is not supported by flatten_tree yet."
)
else:
return [(key, tree)]


T5 = TypeVar("T5")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@ message TfConcreteFunctionHandle {
// SavedModel) but we want to allow calling the function with positional
// arguments.
//
// If this field is not set, this function can only be called with
// keyword-only arguments.
optional StrList input_names = 20;
repeated string input_names = 20;

// Similar to `input_names`, but for outputs.
optional StrList output_names = 30;
}

message StrList {
repeated string elements = 1;
repeated string output_names = 30;
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,15 @@ def _is_dict_only(tree: TfSignature) -> bool:
def _generate_names(
tree: TfSignature, prefix: str = '', *, fixed_name_pattern: bool
) -> _NamesAndSequence:
"""Generates names for a TF signature tree."""
if not fixed_name_pattern and _is_dict_only(tree):
# If the input signature is dict-only, the function shouldn't be
# called with positional arguments anyway, so we don't generate
# names and just return None.
return None, None
# If the input signature is dict-only, the function will generate a
# keyword-based version of the function, and the names will be the key of
# the dict.
flat_pair = obm.tree_util.flatten_tree(tree, '')
return tuple(f'{pair[0]}' for pair in flat_pair), tuple(
v for _, v in flat_pair
)
flat = obm.tree_util.flatten(tree)
return tuple(f'{prefix}_{i}' for i in range(len(flat))), flat

Expand All @@ -111,22 +115,10 @@ def _get_output_names(
)


_T0 = TypeVar('_T0')
_T1 = TypeVar('_T1')


def optional_map(f: Callable[[_T0], _T1], a: _T0 | None) -> _T1 | None:
if a is None:
return None
return f(a)


def _to_optional_str_list(
seq: Sequence[str] | None,
) -> tf_concrete_function_handle_pb2.StrList | None:
return optional_map(
lambda s: tf_concrete_function_handle_pb2.StrList(elements=s), seq
)
) -> list[str] | None:
return [x for x in seq if x is not None]


# TODO(b/400777413): Remove `fixed_name_pattern` in tf2obm once GemaxProd no
Expand Down Expand Up @@ -308,16 +300,16 @@ def _to_args_kwargs_pattern(
)


T0 = TypeVar('T0')
T1 = TypeVar('T1')
T2 = TypeVar('T2')


def unzip2(
xys: Iterable[tuple[T1, T2]],
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
xys: Iterable[tuple[T0, T1]],
) -> tuple[tuple[T0, ...], tuple[T1, ...]]:
"""Unzip sequence of length-2 tuples into two tuples."""
xs: list[T1] = []
ys: list[T2] = []
xs: list[T0] = []
ys: list[T1] = []
for x, y in xys:
xs.append(x)
ys.append(y)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ def test_is_dict_only(self, signature, expected, as_output_signature):
(
"dict",
_DICT,
None,
("a", "b"),
),
)
)
def test_generate_names(self, signature, expected_names, as_output_signature):
if as_output_signature:
signature = _as_output_signature(signature)
prefix = "my_prefix"
if expected_names is not None:
if isinstance(signature, tuple):
expected_names = tuple(prefix + name for name in expected_names)
names, _ = _generate_names(
signature, prefix=prefix, fixed_name_pattern=False
Expand Down Expand Up @@ -508,12 +508,8 @@ def tf_fn(a):
fn_name: \""""
+ pre_processor_name_in_tf
+ """\"
input_names {
elements: "input_0"
}
output_names {
elements: "output_0"
}
input_names: "input_0"
output_names: "output_0"
"""
)
expected_pre_processor_proto = text_format.Parse(
Expand All @@ -536,12 +532,8 @@ def tf_fn(a):
fn_name: \""""
+ post_processor_name_in_tf
+ """\"
input_names {
elements: "input_0"
}
output_names {
elements: "output_0"
}
input_names: "input_0"
output_names: "output_0"
"""
)
expected_post_processor_proto = text_format.Parse(
Expand Down
Loading