Skip to content

Commit 655c54b

Browse files
committed
Update type hints to account for Beam handling of Sequences
Improvements to Beam's type hinting infrastructure found a breakage in this code based on mismatched type hints (stemming from the one-way relationships between lists, sequences, and iterables.) CoGroupByKey outputs Iterables, not Sequences, but these type checks were errantly passing before. PiperOrigin-RevId: 733394393
1 parent 7d9f283 commit 655c54b

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tensorflow_data_validation/statistics/generators/lift_stats_generator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ def _make_y_rates(
434434

435435

436436
def _compute_lifts(
437-
join_info: Tuple[_SlicedYKey, Dict[Text, Sequence[Any]]]
437+
join_info: Tuple[_SlicedYKey, Dict[Text, Iterable[Any]]]
438438
# TODO(b/147153346) update dict value list element type annotation to:
439439
# Sequence[Union[_YRate, _ConditionalYRate]]
440440
) -> Iterator[Tuple[_SlicedFeatureKey, _LiftInfo]]:
@@ -461,7 +461,8 @@ def _compute_lifts(
461461
_LiftInfo(x, y, lift, xy_count, x_count, y_count)).
462462
"""
463463
(slice_key, y), join_inputs = join_info
464-
y_rate = join_inputs['y_rate'][0]
464+
# coerce iterable to list for __getitem__
465+
y_rate = list(join_inputs['y_rate'])[0]
465466
for conditional_y_rate in join_inputs['conditional_y_rate']:
466467
lift = ((float(conditional_y_rate.xy_count) / conditional_y_rate.x_count) /
467468
(float(y_rate.y_count) / y_rate.example_count))

0 commit comments

Comments
 (0)