-
Notifications
You must be signed in to change notification settings - Fork 12
Replace average_on_batch by average_on_trip #299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -24,7 +24,7 @@ def __init__( | |||||
| self, | ||||||
| from_logits=False, | ||||||
| sparse=False, | ||||||
| average_on_batch=False, | ||||||
| average_on_trip=False, | ||||||
| epsilon=1e-10, | ||||||
| name="negative_log_likelihood", | ||||||
| axis=-1, | ||||||
|
|
@@ -40,7 +40,7 @@ def __init__( | |||||
| Whether y_true is given as an index or a one-hot, by default False | ||||||
| epsilon : float, optional | ||||||
| Lower bound for log(.), by default 1e-10 | ||||||
| average_on_batch: bool, optional | ||||||
| average_on_trip: bool, optional | ||||||
| Whether the metric should be averaged over each batch. Typically used to | ||||||
| get metrics averaged by Trip, by default False | ||||||
| name : str, optional | ||||||
|
|
@@ -53,7 +53,7 @@ def __init__( | |||||
| self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") | ||||||
| self.from_logits = from_logits | ||||||
| self.sparse = sparse | ||||||
| self.average_on_batch = average_on_batch | ||||||
| self.average_on_trip = average_on_trip | ||||||
| self.epsilon = epsilon | ||||||
| self.axis = axis | ||||||
|
|
||||||
|
|
@@ -90,11 +90,13 @@ def update_state(self, y_true, y_pred, batch=None, sample_weight=None): | |||||
| y_true * tf.math.log(y_pred) * tf.expand_dims(sample_weight, axis=-1), | ||||||
| axis=self.axis, | ||||||
| ) | ||||||
|
|
||||||
| if batch is not None and self.average_on_batch: | ||||||
| for _, idx in zip(*tf.unique(batch)): | ||||||
| self.nll.assign(self.nll + tf.reduce_mean(nll_value[idx])) | ||||||
| self.n_evals.assign(self.n_evals + 1) | ||||||
| if batch is not None and self.average_on_trip: | ||||||
| unique_trips, segment_ids = tf.unique(batch) | ||||||
| trip_nlls = tf.math.unsorted_segment_mean( | ||||||
| nll_value, segment_ids, tf.shape(unique_trips)[0] | ||||||
| ) | ||||||
| self.nll.assign_add(tf.reduce_sum(trip_nlls)) | ||||||
| self.n_evals.assign_add(tf.cast(tf.shape(unique_trips)[0], self.n_evals.dtype)) | ||||||
| else: | ||||||
| self.nll.assign(self.nll + tf.reduce_sum(nll_value)) | ||||||
| if sample_weight is None: | ||||||
|
|
@@ -118,15 +120,15 @@ class MRR(tf.keras.metrics.Metric): | |||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| average_on_batch=False, | ||||||
| average_on_trip=False, | ||||||
| name="mean_reciprocal_rank", | ||||||
| axis=-1, | ||||||
| **kwargs, | ||||||
| ): | ||||||
| super().__init__(name=name, **kwargs) | ||||||
| self.mrr = self.add_variable(shape=(), initializer="zeros", name="mrr") | ||||||
| self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") | ||||||
| self.average_on_batch = average_on_batch | ||||||
| self.average_on_trip = average_on_trip | ||||||
| self.axis = axis | ||||||
|
|
||||||
| def update_state( | ||||||
|
|
@@ -156,13 +158,17 @@ def update_state( | |||||
| [tf.range(len(y_true)), y_true], axis=1 | ||||||
| ) # Shape: (batch_size, 2) | ||||||
| item_ranks = tf.gather_nd(ranks, item_batch_indices) # Shape: (batch_size,) | ||||||
| mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis) | ||||||
|
|
||||||
| if batch is not None and self.average_on_batch: | ||||||
| self.mrr.assign(self.mrr + tf.reduce_mean(mean_rank)) | ||||||
| self.n_evals.assign(self.n_evals + 1) | ||||||
| # mean_rank = tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32), axis=self.axis) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| if batch is not None and self.average_on_trip: | ||||||
| unique_trips, segment_ids = tf.unique(batch) | ||||||
| reciprocal_ranks = 1.0 / tf.cast(item_ranks, dtype=tf.float32) | ||||||
| trip_mrrs = tf.math.unsorted_segment_mean( | ||||||
| reciprocal_ranks, segment_ids, tf.shape(unique_trips)[0] | ||||||
| ) | ||||||
| self.mrr.assign_add(tf.reduce_sum(trip_mrrs)) | ||||||
| self.n_evals.assign_add(tf.cast(tf.shape(unique_trips)[0], self.n_evals.dtype)) | ||||||
| else: | ||||||
| self.mrr.assign(self.mrr + tf.reduce_sum(mean_rank)) | ||||||
| self.mrr.assign(self.mrr + tf.reduce_sum(tf.cast(1 / item_ranks, dtype=tf.float32))) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For consistency with the
Suggested change
|
||||||
| self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) | ||||||
|
|
||||||
| def result(self): | ||||||
|
|
@@ -181,7 +187,7 @@ class HitRate(tf.keras.metrics.Metric): | |||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| average_on_batch=False, | ||||||
| average_on_trip=False, | ||||||
| top_k: int = 10, | ||||||
| name=None, | ||||||
| axis=-1, | ||||||
|
|
@@ -195,7 +201,7 @@ def __init__( | |||||
| shape=(), initializer="zeros", name=f"hit_rate_at_{self.top_k}" | ||||||
| ) | ||||||
| self.n_evals = self.add_variable(shape=(), initializer="zeros", name="n_evals") | ||||||
| self.average_on_batch = average_on_batch | ||||||
| self.average_on_trip = average_on_trip | ||||||
| self.axis = axis | ||||||
|
|
||||||
| def update_state(self, y_true, y_pred, batch=None): | ||||||
|
|
@@ -223,10 +229,12 @@ def update_state(self, y_true, y_pred, batch=None): | |||||
| ), | ||||||
| axis=1, | ||||||
| ) | ||||||
| hits = tf.reduce_sum(tf.cast(hits_per_batch, tf.float32), axis=self.axis) | ||||||
| if batch is not None and self.average_on_batch: | ||||||
| self.hit_rate.assign(self.hit_rate + tf.reduce_mean(hits)) | ||||||
| self.n_evals.assign(self.n_evals + 1) | ||||||
| hits = tf.cast(hits_per_batch, tf.float32) | ||||||
| if batch is not None and self.average_on_trip: | ||||||
| unique_trips, segment_ids = tf.unique(batch) | ||||||
| trip_means = tf.math.unsorted_segment_mean(hits, segment_ids, tf.shape(unique_trips)[0]) | ||||||
| self.hit_rate.assign_add(tf.reduce_sum(trip_means)) | ||||||
| self.n_evals.assign_add(tf.cast(tf.shape(unique_trips)[0], self.n_evals.dtype)) | ||||||
| else: | ||||||
| self.hit_rate.assign(self.hit_rate + tf.reduce_sum(hits)) | ||||||
| self.n_evals.assign(self.n_evals + tf.shape(y_true)[0]) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring on the following lines (44-45) for this parameter still refers to 'batch'. Please update it to 'trip' for consistency with the parameter name change.