
Commit

Fix the bug in fairness indicators due to deprecation of tf estimator and feature column utils.

PiperOrigin-RevId: 627540407
zhouhao138 authored and Responsible ML Infra Team committed Apr 23, 2024
1 parent 8f2dc52 commit 0e14d91
Showing 3 changed files with 150 additions and 143 deletions.
2 changes: 1 addition & 1 deletion RELEASE.md
@@ -3,7 +3,7 @@
# Current Version (Still in Development)

## Major Features and Improvements

* Update example model to use Keras models instead of estimators.
## Bug Fixes and Other Changes

* Deprecated python 3.8 support
170 changes: 65 additions & 105 deletions fairness_indicators/example_model.py
@@ -14,132 +14,92 @@
# ==============================================================================
"""Demo script to train and evaluate a model.
This script contains boilerplate code to train a DNNClassifier
This script contains boilerplate code to train a Keras Text Classifier
and evaluate it using Tensorflow Model Analysis. Evaluation
results can be visualized using tools like TensorBoard.
Usage:
1. Train model:
demo_script.train_model(...)
2. Evaluate:
demo_script.evaluate_model(...)
"""

import os
import tempfile
from tensorflow import keras
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
import tensorflow_hub as hub
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.addons.fairness.post_export_metrics import fairness_indicators # pylint: disable=unused-import


def train_model(model_dir,
train_tf_file,
label,
text_feature,
feature_map,
module_spec='https://tfhub.dev/google/nnlm-en-dim128/1'):
"""Train model using DNN Classifier.
Args:
model_dir: Directory path to save trained model.
train_tf_file: File containing training TFRecordDataset.
label: Groundtruth label.
text_feature: Text feature to be evaluated.
feature_map: Dict of feature names to their data type.
module_spec: A module spec defining the module to instantiate or a path
where to load a module spec.
Returns:
Trained DNNClassifier.
"""

def train_input_fn():
"""Train Input function."""

def parse_function(serialized):
parsed_example = tf.io.parse_single_example(
serialized=serialized, features=feature_map)
# Adds a weight column to deal with unbalanced classes.
parsed_example['weight'] = tf.add(parsed_example[label], 0.1)
return (parsed_example, parsed_example[label])

train_dataset = tf.data.TFRecordDataset(
filenames=[train_tf_file]).map(parse_function).batch(512)
return train_dataset
TEXT_FEATURE = 'comment_text'
LABEL = 'toxicity'
SLICE = 'slice'
FEATURE_MAP = {
LABEL: tf.io.FixedLenFeature([], tf.float32),
TEXT_FEATURE: tf.io.FixedLenFeature([], tf.string),
SLICE: tf.io.VarLenFeature(tf.string),
}

text_embedding_column = hub.text_embedding_column(
key=text_feature, module_spec=module_spec)

classifier = tf_estimator.DNNClassifier(
hidden_units=[500, 100],
weight_column='weight',
feature_columns=[text_embedding_column],
n_classes=2,
optimizer=tf.train.AdagradOptimizer(learning_rate=0.003),
model_dir=model_dir)
class ExampleParser(keras.layers.Layer):
"""A Keras layer that parses the tf.Example."""

classifier.train(input_fn=train_input_fn, steps=1000)
return classifier
def __init__(self, input_feature_key):
self._input_feature_key = input_feature_key
super().__init__()


def evaluate_model(classifier, validate_tf_file, tfma_eval_result_path,
selected_slice, label, feature_map):
def call(self, serialized_examples):
def get_feature(serialized_example):
parsed_example = tf.io.parse_single_example(
serialized_example, features=FEATURE_MAP
)
return parsed_example[self._input_feature_key]

return tf.map_fn(get_feature, serialized_examples)


class ExampleModel(keras.Model):
"""A Example Keras NLP model."""

def __init__(self, input_feature_key):
super().__init__()
self.parser = ExampleParser(input_feature_key)
self.text_vectorization = keras.layers.TextVectorization(
max_tokens=32,
output_mode='int',
output_sequence_length=32,
)
self.text_vectorization.adapt(
['nontoxic', 'toxic comment', 'test comment', 'abc', 'abcdef', 'random']
)
self.dense1 = keras.layers.Dense(32, activation='relu')
self.dense2 = keras.layers.Dense(1)

def call(self, inputs, training=True, mask=None):
parsed_example = self.parser(inputs)
text_vector = self.text_vectorization(parsed_example)
output1 = self.dense1(tf.cast(text_vector, tf.float32))
output2 = self.dense2(output1)
return output2


def evaluate_model(
classifier_model_path,
validate_tf_file_path,
tfma_eval_result_path,
eval_config,
):
"""Evaluate Model using Tensorflow Model Analysis.
Args:
classifier: Trained classifier model to be evaluated.
validate_tf_file: File containing validation TFRecordDataset.
tfma_eval_result_path: Directory path where eval results will be written.
selected_slice: Feature for slicing the data.
label: Groundtruth label.
feature_map: Dict of feature names to their data type.
classifier_model_path: Path to the trained classifier model to be evaluated.
validate_tf_file_path: File containing validation TFRecordDataset.
tfma_eval_result_path: Directory path where the TFMA eval results will be written.
eval_config: tfma eval_config.
"""

def eval_input_receiver_fn():
"""Eval Input Receiver function."""
serialized_tf_example = tf.compat.v1.placeholder(
dtype=tf.string, shape=[None], name='input_example_placeholder')

receiver_tensors = {'examples': serialized_tf_example}

features = tf.io.parse_example(serialized_tf_example, feature_map)
features['weight'] = tf.ones_like(features[label])

return tfma.export.EvalInputReceiver(
features=features,
receiver_tensors=receiver_tensors,
labels=features[label])

tfma_export_dir = tfma.export.export_eval_savedmodel(
estimator=classifier,
export_dir_base=os.path.join(tempfile.gettempdir(), 'tfma_eval_model'),
eval_input_receiver_fn=eval_input_receiver_fn)

# Define slices that you want the evaluation to run on.
slice_spec = [
tfma.slicer.SingleSliceSpec(), # Overall slice
tfma.slicer.SingleSliceSpec(columns=[selected_slice]),
]

# Add the fairness metrics.
# pytype: disable=module-attr
add_metrics_callbacks = [
tfma.post_export_metrics.fairness_indicators(
thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)
]
# pytype: enable=module-attr

eval_shared_model = tfma.default_eval_shared_model(
eval_saved_model_path=tfma_export_dir,
add_metrics_callbacks=add_metrics_callbacks)
eval_saved_model_path=classifier_model_path, eval_config=eval_config
)

# Run the fairness evaluation.
tfma.run_model_analysis(
eval_shared_model=eval_shared_model,
data_location=validate_tf_file,
data_location=validate_tf_file_path,
output_path=tfma_eval_result_path,
slice_spec=slice_spec)
eval_config=eval_config,
)
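With the estimator-based EvalSavedModel export gone, TFMA is now driven entirely by an EvalConfig rather than by slice_spec and add_metrics_callbacks. A usage sketch of the new evaluate_model signature follows; the /tmp paths are placeholders and the EvalConfig is a trimmed-down version of the one built in the test below.

# Illustrative only; assumes a SavedModel and a validation TFRecord already exist
# at the placeholder paths.
import tensorflow_model_analysis as tfma
from google.protobuf import text_format
from fairness_indicators import example_model

eval_config = text_format.Parse(
    """
    model_specs { label_key: "toxicity" }
    slicing_specs {}
    slicing_specs { feature_keys: ["slice"] }
    metrics_specs {
      metrics { class_name: "ExampleCount" }
      metrics { class_name: "FairnessIndicators" }
    }
    """,
    tfma.EvalConfig(),
)

example_model.evaluate_model(
    '/tmp/example_model',      # SavedModel exported by ExampleModel.save
    '/tmp/validate.tfrecord',  # TFRecord of serialized tf.Examples
    '/tmp/tfma_eval_result',   # output directory for TFMA results
    eval_config,
)

# Per-slice fairness metrics can then be inspected or visualized.
results = tfma.load_eval_result('/tmp/tfma_eval_result')
for slice_key, metrics in results.slicing_metrics:
  print(slice_key, metrics)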
121 changes: 84 additions & 37 deletions fairness_indicators/example_model_test.py
@@ -21,22 +21,17 @@
import datetime
import os
import tempfile

from fairness_indicators import example_model
import numpy as np
import six
from tensorflow import keras
import tensorflow.compat.v1 as tf
import tensorflow_model_analysis as tfma
from tensorflow_model_analysis.slicer import slicer_lib as slicer

tf.compat.v1.enable_eager_execution()
from google.protobuf import text_format

TEXT_FEATURE = 'comment_text'
LABEL = 'toxicity'
SLICE = 'slice'
FEATURE_MAP = {
LABEL: tf.io.FixedLenFeature([], tf.float32),
TEXT_FEATURE: tf.io.FixedLenFeature([], tf.string),
SLICE: tf.io.VarLenFeature(tf.string),
}
tf.compat.v1.enable_eager_execution()


class ExampleModelTest(tf.test.TestCase):
@@ -51,13 +46,13 @@ def setUp(self):

def _create_example(self, comment_text, label, slice_value):
example = tf.train.Example()
example.features.feature[TEXT_FEATURE].bytes_list.value[:] = [
example.features.feature[example_model.TEXT_FEATURE].bytes_list.value[:] = [
six.ensure_binary(comment_text, 'utf8')
]
example.features.feature[SLICE].bytes_list.value[:] = [
example.features.feature[example_model.SLICE].bytes_list.value[:] = [
six.ensure_binary(slice_value, 'utf8')
]
example.features.feature[LABEL].float_list.value[:] = [label]
example.features.feature[example_model.LABEL].float_list.value[:] = [label]
return example

def _create_data(self):
@@ -85,34 +80,86 @@ def _write_tf_records(self, examples):
return data_location

def test_example_model(self):
train_tf_file = self._write_tf_records(self._create_data())
classifier = example_model.train_model(self._model_dir, train_tf_file,
LABEL, TEXT_FEATURE, FEATURE_MAP)

validate_tf_file = self._write_tf_records(self._create_data())
data = self._create_data()
classifier = example_model.ExampleModel(example_model.TEXT_FEATURE)
classifier.compile(optimizer=keras.optimizers.Adam(), loss='mse')
print([e.SerializeToString() for e in data])
classifier.predict(tf.constant([e.SerializeToString() for e in data]))
classifier.fit(
tf.constant([e.SerializeToString() for e in data]),
np.array([
e.features.feature[example_model.LABEL].float_list.value[:][0]
for e in data
]),
)
classifier.save(self._model_dir, save_format='tf')

eval_config = text_format.Parse(
"""
model_specs {
signature_name: "serving_default"
prediction_key: "predictions" # placeholder
label_key: "toxicity" # placeholder
}
slicing_specs {}
slicing_specs {
feature_keys: ["slice"]
}
metrics_specs {
metrics {
class_name: "ExampleCount"
}
metrics {
class_name: "FairnessIndicators"
}
}
""",
tfma.EvalConfig(),
)

validate_tf_file_path = self._write_tf_records(data)
tfma_eval_result_path = os.path.join(self._model_dir, 'tfma_eval_result')
example_model.evaluate_model(classifier, validate_tf_file,
tfma_eval_result_path, SLICE, LABEL,
FEATURE_MAP)
example_model.evaluate_model(
self._model_dir,
validate_tf_file_path,
tfma_eval_result_path,
eval_config,
)

expected_slice_keys = [
'Overall', 'slice:slice3', 'slice:slice1', 'slice:slice2'
]
evaluation_results = tfma.load_eval_result(tfma_eval_result_path)

self.assertLen(evaluation_results.slicing_metrics, 4)

# Verify if false_positive_rate metrics are computed for all values of
# slice.
for (slice_key, metric_value) in evaluation_results.slicing_metrics:
slice_key = slicer.stringify_slice_key(slice_key)
self.assertIn(slice_key, expected_slice_keys)
self.assertGreaterEqual(
1.0, metric_value['']['']
['post_export_metrics/false_positive_rate@0.50']['doubleValue'])
self.assertLessEqual(
0.0, metric_value['']['']
['post_export_metrics/false_positive_rate@0.50']['doubleValue'])
expected_slice_keys = [
(),
(('slice', 'slice1'),),
(('slice', 'slice2'),),
(('slice', 'slice3'),),
]
slice_keys = [
slice_key for slice_key, _ in evaluation_results.slicing_metrics
]
self.assertEqual(set(expected_slice_keys), set(slice_keys))
# Verify part of the metrics of fairness indicators
metric_values = dict(evaluation_results.slicing_metrics)[(
('slice', 'slice1'),
)]['']['']
self.assertEqual(metric_values['example_count'], {'doubleValue': 5.0})

self.assertEqual(
metric_values['fairness_indicators_metrics/false_positive_rate@0.1'],
{'doubleValue': 0.0},
)
self.assertEqual(
metric_values['fairness_indicators_metrics/false_negative_rate@0.1'],
{'doubleValue': 1.0},
)
self.assertEqual(
metric_values['fairness_indicators_metrics/true_positive_rate@0.1'],
{'doubleValue': 0.0},
)
self.assertEqual(
metric_values['fairness_indicators_metrics/true_negative_rate@0.1'],
{'doubleValue': 1.0},
)


if __name__ == '__main__':
