Add ListMLE Loss #130
base: main
Not sure why the tests did not run; commenting so that they run.
Sorry for the delay, did one pass! Some questions below, let me know what you think.
I'll take a closer look again!
api_gen.py
Outdated
@@ -8,7 +8,7 @@

import os
import shutil

import pre_commit
This shouldn't be there.
You're right, that shouldn't be there. I will remove it in the next commit along with other changes.
keras_rs/api/layers/__init__.py
Outdated
from keras_rs.src.layers.embedding.distributed_embedding import DistributedEmbedding as DistributedEmbedding
from keras_rs.src.layers.embedding.distributed_embedding_config import FeatureConfig as FeatureConfig
from keras_rs.src.layers.embedding.distributed_embedding_config import TableConfig as TableConfig
from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce as EmbedReduce
from keras_rs.src.layers.feature_interaction.dot_interaction import DotInteraction as DotInteraction
from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross as FeatureCross
from keras_rs.src.layers.retrieval.brute_force_retrieval import BruteForceRetrieval as BruteForceRetrieval
from keras_rs.src.layers.retrieval.hard_negative_mining import HardNegativeMining as HardNegativeMining
from keras_rs.src.layers.retrieval.remove_accidental_hits import RemoveAccidentalHits as RemoveAccidentalHits
Hmmm, these shouldn't be getting re-formatted, in my opinion. Are you on the correct Ruff version? The same applies to the other files.
from keras_rs.src.losses.list_mle_loss import ListMLELoss


class ListMLELossTest(testing.TestCase, parameterized.TestCase):
Quick question - have you verified the outputs with TFRS' ListMLELoss?
)

self.temperature = temperature
self._epsilon = 1e-10
Should we define it here, like this, or should we pull it from keras.config.epsilon()? What do you think?
I think pulling it from keras.config.epsilon() would work, but in this case defining epsilon locally gives us the flexibility to choose a value other than the default 1e-7.
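One possible middle ground between the two options, sketched in plain Python: accept an optional epsilon and fall back to the backend default only when none is given. The name `resolve_epsilon` and the constant `BACKEND_DEFAULT_EPSILON` are hypothetical; the constant stands in for the value returned by keras.config.epsilon() so that the snippet runs without Keras installed. This is not what the PR does, just an illustration of the trade-off.

```python
# Stand-in for keras.config.epsilon(); 1e-7 is the default mentioned above.
BACKEND_DEFAULT_EPSILON = 1e-7


def resolve_epsilon(epsilon=None):
    """Return the caller's epsilon if given, else the backend default.

    Hypothetical helper, not part of keras_rs.
    """
    if epsilon is None:
        return BACKEND_DEFAULT_EPSILON
    return float(epsilon)


# A loss could keep its tighter local value explicitly...
local_eps = resolve_epsilon(1e-10)
# ...while other callers inherit the global setting.
global_eps = resolve_epsilon()
```

With this shape, the loss keeps 1e-10 by passing it explicitly, and changing the global setting still affects anything that does not override it.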
keras_rs/src/losses/list_mle_loss.py
Outdated
labels_for_sorting = ops.where(valid_mask, labels, ops.full_like(labels, -1e9))
logits_masked = ops.where(valid_mask, logits, ops.full_like(logits, -1e9))

sorted_indices = ops.argsort(-labels_for_sorting, axis=-1)
-labels_for_sorting --> ops.negative(labels_for_sorting)
I haven't gone through this perfectly, but do you think there is any chance of re-using
def sort_by_scores(
Completely okay if not! Just wanted to make you aware of this function, and check with you if it can be reused
Yes, exactly. I sorted manually here; we can reuse the function sort_by_scores instead.
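For reference while reviewing the masking and sorting step, here is a minimal NumPy sketch of the computation this thread is about: mask invalid items, sort by descending label, then sum the negative Plackett-Luce log-likelihood over the sorted positions. The function name `list_mle_reference` is hypothetical, and the sketch leaves out temperature and lambda weights. It is not the PR's implementation, and it has not been checked against TFRS.

```python
import numpy as np


def list_mle_reference(labels, logits, mask=None):
    """Per-list ListMLE loss for arrays of shape (batch, list_size)."""
    labels = np.asarray(labels, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    if mask is None:
        mask = np.ones(labels.shape, dtype=bool)
    mask = np.asarray(mask, dtype=bool)

    # Same trick as the diff: a very negative fill pushes masked items
    # to the end of the ordering and makes their softmax weight vanish.
    labels_for_sorting = np.where(mask, labels, -1e9)
    logits_masked = np.where(mask, logits, -1e9)

    # Descending sort by label; "stable" keeps ties in input order.
    order = np.argsort(-labels_for_sorting, axis=-1, kind="stable")
    sorted_logits = np.take_along_axis(logits_masked, order, axis=-1)
    sorted_mask = np.take_along_axis(mask, order, axis=-1)

    # log(sum_{j >= i} exp(s_j)) as a reversed running logaddexp,
    # which avoids exponentiating large logits directly.
    log_tail = np.logaddexp.accumulate(sorted_logits[..., ::-1], axis=-1)
    log_tail = log_tail[..., ::-1]

    # Negative log-likelihood per position; masked positions add zero.
    per_item = np.where(sorted_mask, log_tail - sorted_logits, 0.0)
    return per_item.sum(axis=-1)


# Two items with equal scores: the top pick has probability 1/2.
loss = list_mle_reference([[1.0, 0.0]], [[0.0, 0.0]])
```

In the two-item case above the loss is log(2), since the model assigns probability 1/2 to the correct first choice. A masked item leaves the result unchanged regardless of its logit, which is the behaviour the `-1e9` fill is meant to produce.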
ef5ee24 to 63388e9
Added ListMLELoss code to listwise ranking. This code does not consider Lambda weights.
Here is the gist for verified results with TFRS ListMLELoss