Skip to content

Commit

Permalink
[Metrics] Fixed metrics with no required tags currently being blackho…
Browse files Browse the repository at this point in the history
…le'd (ray-project#36298)

Currently, Metric implementation requires you to specify tag_keys of the tags that have to be recorded. However it also does NOT allow you to record additional tags on top of the ones you have provided during creation.

This is problematic since it couples metric creation with its usage: you can't, say add new tags to metric usage w/o updating its definition (which might be in a separate module, library, etc)

One straightforward example: if you haven't specified any tag_keys during metric creation it's impossible for you to add any tags during recording.

Proposal:

Treat tag_keys as only a set of required tags (ie allow additional tags to be specified during recording)
  • Loading branch information
alexeykudinkin authored Feb 7, 2024
1 parent ccfcb21 commit 63e1586
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 26 deletions.
67 changes: 58 additions & 9 deletions python/ray/tests/test_metrics_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,18 +180,20 @@ def _setup_cluster_for_test(request, ray_start_cluster):

worker_should_exit = SignalActor.remote()

extra_tags = {"ray_version": ray.__version__}

# Generate a metric in the driver.
counter = Counter("test_driver_counter", description="desc")
counter.inc()
counter.inc(tags=extra_tags)

# Generate some metrics from actor & tasks.
@ray.remote
def f():
counter = Counter("test_counter", description="desc")
counter.inc()
counter = ray.get(ray.put(counter)) # Test serialization.
counter.inc()
counter.inc(2)
counter.inc(tags=extra_tags)
counter.inc(2, tags=extra_tags)
ray.get(worker_should_exit.wait.remote())

# Generate some metrics for the placement group.
Expand All @@ -206,7 +208,7 @@ async def ping(self):
"test_histogram", description="desc", boundaries=[0.1, 1.6]
)
histogram = ray.get(ray.put(histogram)) # Test serialization.
histogram.observe(1.5)
histogram.observe(1.5, tags=extra_tags)
ray.get(worker_should_exit.wait.remote())

a = A.remote()
Expand Down Expand Up @@ -730,6 +732,58 @@ def test_basic_custom_metrics(metric_mock):
metric_mock.record.assert_called_with(8, tags=tags)


def test_custom_metrics_with_extra_tags(metric_mock):
base_tags = {"a": "1"}
extra_tags = {"a": "1", "b": "2"}

# -- Counter --
count = Counter("count", tag_keys=("a",))
with pytest.raises(ValueError):
count.inc(1)

count._metric = metric_mock

# Increment with base tags
count.inc(1, tags=base_tags)
metric_mock.record.assert_called_with(1, tags=base_tags)
metric_mock.reset_mock()

# Increment with extra tags
count.inc(1, tags=extra_tags)
metric_mock.record.assert_called_with(1, tags=extra_tags)
metric_mock.reset_mock()

# -- Gauge --
gauge = Gauge("gauge", description="gauge", tag_keys=("a",))
gauge._metric = metric_mock

# Record with base tags
gauge.record(4, tags=base_tags)
metric_mock.record.assert_called_with(4, tags=base_tags)
metric_mock.reset_mock()

# Record with extra tags
gauge.record(4, tags=extra_tags)
metric_mock.record.assert_called_with(4, tags=extra_tags)
metric_mock.reset_mock()

# -- Histogram
histogram = Histogram(
"hist", description="hist", boundaries=[1.0, 3.0], tag_keys=("a",)
)
histogram._metric = metric_mock

# Record with base tags
histogram.observe(8, tags=base_tags)
metric_mock.record.assert_called_with(8, tags=base_tags)
metric_mock.reset_mock()

# Record with extra tags
histogram.observe(8, tags=extra_tags)
metric_mock.record.assert_called_with(8, tags=extra_tags)
metric_mock.reset_mock()


def test_custom_metrics_info(metric_mock):
# Make sure .info public method works.
histogram = Histogram(
Expand Down Expand Up @@ -821,11 +875,6 @@ def test_custom_metrics_validation(shutdown_only):
with pytest.raises(ValueError):
metric.inc(1.0, {"a": "2"})

# Extra tag not in tag_keys.
metric = Counter("name", tag_keys=("a",))
with pytest.raises(ValueError):
metric.inc(1.0, {"a": "1", "b": "2"})

# tag_keys must be tuple.
with pytest.raises(TypeError):
Counter("name", tag_keys="a")
Expand Down
40 changes: 23 additions & 17 deletions python/ray/util/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def set_default_tags(self, default_tags: Dict[str, str]):
return self

def record(
self, value: Union[int, float], tags: Dict[str, str] = None, _internal=False
self,
value: Union[int, float],
tags: Optional[Dict[str, str]] = None,
_internal=False,
) -> None:
"""Record the metric point of the metric.
Expand Down Expand Up @@ -110,26 +113,29 @@ def record(
"instead."
)

if tags is not None:
for val in tags.values():
if not isinstance(val, str):
raise TypeError(f"Tag values must be str, got {type(val)}.")
final_tags = self._get_final_tags(tags)
self._validate_tags(final_tags)
self._metric.record(value, tags=final_tags)

def _get_final_tags(self, tags):
if not tags:
return self._default_tags

final_tags = {}
tags_copy = tags.copy() if tags else {}
for val in tags.values():
if not isinstance(val, str):
raise TypeError(f"Tag values must be str, got {type(val)}.")

return {**self._default_tags, **tags}

def _validate_tags(self, final_tags):
missing_tags = []
for tag_key in self._tag_keys:
# Prefer passed tags over default tags.
if tags is not None and tag_key in tags:
final_tags[tag_key] = tags_copy.pop(tag_key)
elif tag_key in self._default_tags:
final_tags[tag_key] = self._default_tags[tag_key]
else:
raise ValueError(f"Missing value for tag key {tag_key}.")
if tag_key not in final_tags:
missing_tags.append(tag_key)

if len(tags_copy) > 0:
raise ValueError(f"Unrecognized tag keys: {list(tags_copy.keys())}.")

self._metric.record(value, tags=final_tags)
if missing_tags:
raise ValueError(f"Missing value for tag key(s): {','.join(missing_tags)}.")

@property
def info(self) -> Dict[str, Any]:
Expand Down

0 comments on commit 63e1586

Please sign in to comment.