Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions choice_learn/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self,
from_logits=False,
sparse=False,
average_on_batch=False,
average_on_trip=False,
epsilon=1e-10,
name="negative_log_likelihood",
axis=-1,
Expand All @@ -40,7 +40,7 @@ def __init__(
Whether y_true is given as an index or a one-hot, by default False
epsilon : float, optional
Lower bound for log(.), by default 1e-10
average_on_batch: bool, optional
average_on_trip: bool, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring on the following lines (44-45) for this parameter still refers to 'batch'. Please update it to 'trip' for consistency with the parameter name change.

Whether the metric should be averaged within each trip, i.e. over the rows that share
the same value in `batch`, rather than over individual rows, by default False
name : str, optional
Expand All @@ -53,7 +53,7 @@ def __init__(
self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals")
self.from_logits = from_logits
self.sparse = sparse
self.average_on_batch = average_on_batch
self.average_on_trip = average_on_trip
self.epsilon = epsilon
self.axis = axis

Expand Down Expand Up @@ -90,11 +90,13 @@ def update_state(self, y_true, y_pred, batch=None, sample_weight=None):
y_true * tf.math.log(y_pred) * tf.expand_dims(sample_weight, axis=-1),
axis=self.axis,
)

if batch is not None and self.average_on_batch:
for _, idx in zip(*tf.unique(batch)):
self.nll.assign(self.nll + tf.reduce_mean(nll_value[idx]))
self.n_evals.assign(self.n_evals + 1)
if batch is not None and self.average_on_trip:
unique_trips, segment_ids = tf.unique(batch)
trip_nlls = tf.math.unsorted_segment_mean(
nll_value, segment_ids, tf.shape(unique_trips)[0]
)
self.nll.assign_add(tf.reduce_sum(trip_nlls))
self.n_evals.assign_add(tf.cast(tf.shape(unique_trips)[0], self.n_evals.dtype))
else:
self.nll.assign(self.nll + tf.reduce_sum(nll_value))
if sample_weight is None:
Expand All @@ -118,15 +120,15 @@ class MRR(tf.keras.metrics.Metric):

def __init__(
self,
average_on_batch=False,
average_on_trip=False,
name="mean_reciprocal_rank",
axis=-1,
**kwargs,
):
super().__init__(name=name, **kwargs)
self.mrr = self.add_variable(shape=(), initializer="zeros", name="mrr")
self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals")
self.average_on_batch = average_on_batch
self.average_on_trip = average_on_trip
self.axis = axis

def update_state(
Expand Down Expand Up @@ -156,13 +158,17 @@ def update_state(
[tf.range(len(y_true)), y_true], axis=1
) # Shape: (batch_size, 2)
item_ranks = tf.gather_nd(ranks, item_batch_indices) # Shape: (batch_size,)
mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis)

if batch is not None and self.average_on_batch:
self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank))
self.n_evals.assign(self.n_evals + 1)
# mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This commented-out code should be removed to improve code clarity.

if batch is not None and self.average_on_trip:
unique_trips, segment_ids = tf.unique(batch)
reciprocal_ranks = 1.0 / tf.cast(item_ranks, dtype=tf.float32)
trip_mrrs = tf.math.unsorted_segment_mean(
reciprocal_ranks, segment_ids, tf.shape(unique_trips)[0]
)
self.mrr.assign_add(tf.reduce_sum(trip_mrrs))
self.n_evals.assign_add(tf.cast(tf.shape(unique_trips)[0], self.n_evals.dtype))
else:
self.mrr.assign(self.mrr + tf.reduce_sum(mean_rank))
self.mrr.assign(self.mrr + tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the if branch and for better readability, consider using assign_add here. It would also be good to apply the same change to the update of self.n_evals on the next line for consistency.

Suggested change
self.mrr.assign(self.mrr + tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32)))
self.mrr.assign_add(tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32)))

self.n_evals.assign(self.n_evals + tf.shape(y_true)[0])

def result(self):
Expand All @@ -181,7 +187,7 @@ class HitRate(tf.keras.metrics.Metric):

def __init__(
self,
average_on_batch=False,
average_on_trip=False,
top_k: int = 10,
name=None,
axis=-1,
Expand All @@ -195,7 +201,7 @@ def __init__(
shape=(), initializer="zeros", name=f"hit_rate_at_{self.top_k}"
)
self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals")
self.average_on_batch = average_on_batch
self.average_on_trip = average_on_trip
self.axis = axis

def update_state(self, y_true, y_pred, batch=None):
Expand Down Expand Up @@ -223,10 +229,12 @@ def update_state(self, y_true, y_pred, batch=None):
),
axis=1,
)
hits = tf.reduce_sum(tf.cast(hits_per_batch, tf.float32), axis=self.axis)
if batch is not None and self.average_on_batch:
self.hit_rate.assign(self.hit_rate + tf.reduce_mean(hits))
self.n_evals.assign(self.n_evals + 1)
hits = tf.cast(hits_per_batch, tf.float32)
if batch is not None and self.average_on_trip:
unique_trips, segment_ids = tf.unique(batch)
trip_means = tf.math.unsorted_segment_mean(hits, segment_ids, tf.shape(unique_trips)[0])
self.hit_rate.assign_add(tf.reduce_sum(trip_means))
self.n_evals.assign_add(tf.cast(tf.shape(unique_trips)[0], self.n_evals.dtype))
else:
self.hit_rate.assign(self.hit_rate + tf.reduce_sum(hits))
self.n_evals.assign(self.n_evals + tf.shape(y_true)[0])
Expand Down