Skip to content

Commit 2a92bca

Browse files
Use exported streamz in benchmark and add environment variable for specifying streamz to use as metrics
PiperOrigin-RevId: 526673279
1 parent cf09b92 commit 2a92bca

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

official/benchmark/perfzero_benchmark.py

+39
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,18 @@
2424
from absl.testing import flagsaver
2525
import tensorflow as tf
2626

27+
from tensorflow.python import pywrap_tfe_monitoring_reader as monitoring
28+
2729
FLAGS = flags.FLAGS
2830

31+
# Environment variable that defines extra metrics to include based on streamz.
32+
# Is a comma separated list of streamz metrics which will result in metrics
33+
# added to the report where the name of the metric is the basename of the
34+
# streamz.
35+
# For example: "/tensorflow/core/tf_function/graph_building_time_usecs"
36+
# would add one new metric named "graph_building_time_usecs"
37+
TEST_BENCHMARK_STREAMZ_EXTRA_METRICS = 'BENCHMARK_STREAMZ_EXTRA_METRICS'
38+
2939

3040
class PerfZeroBenchmark(tf.test.Benchmark):
3141
"""Common methods used in PerfZero Benchmarks.
@@ -75,6 +85,14 @@ def __init__(self,
7585

7686
logging.info('root_data_dir: %s', root_data_dir)
7787

88+
self.extra_metrics = self._get_extra_metrics_readers()
89+
logging.info('including extra streamz metrics: %s', self.extra_metrics)
90+
91+
def report_benchmark(self, metrics=None, **kwargs):
92+
"""Report a benchmark."""
93+
metrics = self._read_extra_metrics() + (metrics or [])
94+
super().report_benchmark(metrics=metrics, **kwargs)
95+
7896
@property
7997
def tpu(self):
8098
return self.default_flags.get('tpu', None)
@@ -83,6 +101,27 @@ def _get_model_dir(self, folder_name):
83101
"""Returns directory to store info, e.g. saved model and event log."""
84102
return os.path.join(self.output_dir, folder_name)
85103

104+
def _get_extra_metrics_readers(self):
105+
streamz = os.environ.get(TEST_BENCHMARK_STREAMZ_EXTRA_METRICS, None)
106+
if streamz:
107+
return [self._get_metric_reader(x) for x in streamz.split(',')]
108+
return []
109+
110+
def _get_metric_reader(self, streamz):
111+
return {
112+
'name': os.path.basename(streamz),
113+
'reader': monitoring.TFE_MonitoringNewCounterReader(streamz),
114+
}
115+
116+
def _read_extra_metrics(self):
117+
return [self._read_extra_metric(metric) for metric in self.extra_metrics]
118+
119+
def _read_extra_metric(self, metric):
120+
return {
121+
'name': metric['name'],
122+
'value': monitoring.TFE_MonitoringReadCounter0(metric['reader']),
123+
}
124+
86125
def _setup(self):
87126
"""Sets up and resets flags before each test."""
88127
logging.set_verbosity(logging.INFO)

0 commit comments

Comments
 (0)