This repository was archived by the owner on Sep 27, 2024. It is now read-only.

Commit 87c0b4d

Merge pull request #223 from casassg/casassg/multi-output-model
Add support for multiple output models
2 parents 8e45afb + e0e85c9


model_card_toolkit/utils/tfx_util.py

Lines changed: 36 additions & 25 deletions
@@ -442,31 +442,42 @@ def _parse_array_value(array: Dict[str, Any]) -> str:
     logging.warning('Received unexpected array %s', str(array))
     return ''
 
-  for slice_repr, metrics_for_slice in (
-      eval_result.get_metrics_for_all_slices().items()):
-    # Parse the slice name
-    if not isinstance(slice_repr, tuple):
-      raise ValueError(
-          f'Expected EvalResult slices to be tuples; found {type(slice_repr)}')
-    slice_name = '_X_'.join(f'{a}_{b}' for a, b in slice_repr)
-    for metric_name, metric_value in metrics_for_slice.items():
-      # Parse the metric value
-      parsed_value = ''
-      if 'doubleValue' in metric_value:
-        parsed_value = metric_value['doubleValue']
-      elif 'boundedValue' in metric_value:
-        parsed_value = metric_value['boundedValue']['value']
-      elif 'arrayValue' in metric_value:
-        parsed_value = _parse_array_value(metric_value['arrayValue'])
-      else:
-        logging.warning(
-            'Expected doubleValue, boundedValue, or arrayValue; found %s',
-            metric_value.keys())
-      if parsed_value:
-        # Create the PerformanceMetric and append to the ModelCard
-        metric = model_card_module.PerformanceMetric(
-            type=metric_name, value=str(parsed_value), slice=slice_name)
-        model_card.quantitative_analysis.performance_metrics.append(metric)
+  # NOTE: When a multi-output model is passed, each output's metrics live
+  # under its own output_name key. In that case, namespace each metric in the
+  # quantitative_analysis as output_name.metric_name to distinguish outputs.
+  output_names = set()
+  for slicing_metric in eval_result.slicing_metrics:
+    for output_name in slicing_metric[1]:
+      output_names.add(output_name)
+  for output_name in sorted(output_names):
+    for slice_repr, metrics_for_slice in (
+        eval_result.get_metrics_for_all_slices(output_name=output_name).items()):
+      # Parse the slice name
+      if not isinstance(slice_repr, tuple):
+        raise ValueError(
+            f'Expected EvalResult slices to be tuples; found {type(slice_repr)}')
+      slice_name = '_X_'.join(f'{a}_{b}' for a, b in slice_repr)
+      for metric_name, metric_value in metrics_for_slice.items():
+        # Parse the metric value
+        parsed_value = ''
+        if 'doubleValue' in metric_value:
+          parsed_value = metric_value['doubleValue']
+        elif 'boundedValue' in metric_value:
+          parsed_value = metric_value['boundedValue']['value']
+        elif 'arrayValue' in metric_value:
+          parsed_value = _parse_array_value(metric_value['arrayValue'])
+        else:
+          logging.warning(
+              'Expected doubleValue, boundedValue, or arrayValue; found %s',
+              metric_value.keys())
+        if parsed_value:
+          metric_type = metric_name
+          if output_name:
+            metric_type = f"{output_name}.{metric_name}"
+          # Create the PerformanceMetric and append to the ModelCard
+          metric = model_card_module.PerformanceMetric(
+              type=metric_type, value=str(parsed_value), slice=slice_name)
+          model_card.quantitative_analysis.performance_metrics.append(metric)
 
 
 def filter_metrics(
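
For context, here is a minimal, runnable sketch of the namespacing behavior this commit introduces, using plain dicts in place of a real tfma.EvalResult. The slicing_metrics shape below ((slice_repr, {output_name: metrics}) pairs, with TFMA's sub-key level omitted) and the output names 'price' and 'rating' are illustrative assumptions, not toolkit API.

# Hypothetical stand-in for eval_result.slicing_metrics on a two-output model.
slicing_metrics = [
    ((('gender', 'male'),), {
        'price': {'mean_squared_error': {'doubleValue': 0.12}},
        'rating': {'accuracy': {'boundedValue': {'value': 0.91}}},
    }),
    ((('gender', 'female'),), {
        'price': {'mean_squared_error': {'doubleValue': 0.15}},
        'rating': {'accuracy': {'boundedValue': {'value': 0.89}}},
    }),
]


def _parse_value(metric_value):
  # Mirrors the commit's doubleValue / boundedValue branches (arrayValue is
  # omitted here for brevity).
  if 'doubleValue' in metric_value:
    return metric_value['doubleValue']
  if 'boundedValue' in metric_value:
    return metric_value['boundedValue']['value']
  return ''


# First pass: collect every output name, as the new code does from
# eval_result.slicing_metrics (slicing_metric[1] is the per-output dict).
output_names = set()
for slicing_metric in slicing_metrics:
  output_names.update(slicing_metric[1])

# Second pass: one (type, value, slice) triple per metric, with the type
# namespaced as output_name.metric_name so metrics from different outputs
# don't collide.
performance_metrics = []
for output_name in sorted(output_names):
  for slice_repr, outputs in slicing_metrics:
    slice_name = '_X_'.join(f'{a}_{b}' for a, b in slice_repr)
    for metric_name, metric_value in outputs[output_name].items():
      metric_type = f'{output_name}.{metric_name}' if output_name else metric_name
      performance_metrics.append(
          (metric_type, str(_parse_value(metric_value)), slice_name))

for entry in performance_metrics:
  print(entry)
# ('price.mean_squared_error', '0.12', 'gender_male')
# ('price.mean_squared_error', '0.15', 'gender_female')
# ('rating.accuracy', '0.91', 'gender_male')
# ('rating.accuracy', '0.89', 'gender_female')

For a single-output model, TFMA reports metrics under an empty output name, which is what the commit's `if output_name:` guard handles: the metric type stays un-namespaced and existing single-output model cards are unchanged.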
