Feat/loss soft hkr #86
base: master
deel/lip/losses.py
Outdated
Note that `y_true` should be one-hot encoded or pre-processed with the
`deel.lip.utils.process_labels_for_multi_gpu()` function.

Using a multi-GPU/TPU strategy requires to set `multi_gpu` to True and to
pre-process the labels `y_true` with the
`deel.lip.utils.process_labels_for_multi_gpu()` function.
If multi-GPU is not supported, we might remove/update these lines.
comment on multi-gpu removed
deel/lip/losses.py
Outdated
(self.min_margin_v,),
dtype=tf.float32,
constraint=lambda x: tf.clip_by_value(x, 0.005, 1000),
name="moving_mean",
Maybe rename this variable to avoid confusion with `self.moving_mean`, which has the same name.
deel/lip/losses.py
Outdated
if self.one_hot_ytrue:
    y_true = tf.where(y_true > 0, 1, -1)  # switch to +/-1
Is it possible to use the same line as in hinge_preproc? `sign = tf.where(y_true > 0, 1, -1)` works whether `y_true` is 0/1 or -1/1. If so, `self.one_hot_ytrue` can be removed from the arguments of the function.
It is also possible to preprocess the `y_true` labels only once in the `hkr()` function, instead of twice in `preproc_kr` and `preproc_hinge`, as is already done for `F_soft_KR`, which is computed once at the beginning of `hkr()`.
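The claim that a single `where` call handles both label encodings can be checked with a small sketch. It uses NumPy's `np.where` as a stand-in for `tf.where` (an assumption made to keep the example light on dependencies; the two behave the same on this input), and the helper name `to_signed_labels` is hypothetical, not part of deel-lip.

```python
import numpy as np


def to_signed_labels(y_true):
    """Map labels to +/-1: strictly positive entries become +1, the rest -1."""
    return np.where(np.asarray(y_true) > 0, 1, -1)


# One-hot (0/1) labels and already-signed (-1/+1) labels give the same result,
# so a separate flag such as `one_hot_ytrue` is not needed for this step.
one_hot = np.array([[0, 1, 0], [1, 0, 0]])
signed = np.array([[-1, 1, -1], [1, -1, -1]])

signs_from_one_hot = to_signed_labels(one_hot)
signs_from_signed = to_signed_labels(signed)
```

Calling such a helper once at the top of `hkr()` and passing the result to both preprocessing steps would avoid converting the labels twice.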
Co-authored-by: cofri <cofri@users.noreply.github.com>
…inibatches is not possible with this loss
temperature (float): factor for softmax temperature
    (higher value increases the weight of the highest non y_true logits)
alpha_mean (float): geometric mean factor
The docstrings of `temperature` and `alpha_mean` are not in the same order as in the `__init__()` signature.
    name="current_mean",
)

self.temperature = temperature * self.min_margin_v
I'm wondering whether it's a problem when saving with `get_config()` if the temperature is not the same as at initialization of the loss (i.e. keeping `temperature * self.min_margin_v` instead of `temperature` alone in `get_config`).
`self.temperature` has to be divided by `self.min_margin_v` before saving. The only variable is `self.current_mean`, but the question of serializing its value is a good one.
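A minimal sketch of the round trip discussed here, under stated assumptions: `SoftHKRConfigSketch` is a hypothetical stand-in written in plain Python, not the deel-lip class, and it ignores the other fields a Keras loss config would carry as well as the `current_mean` variable. It only illustrates dividing the stored temperature by the margin before saving.

```python
class SoftHKRConfigSketch:
    """Hypothetical stand-in for the loss; not the deel-lip class."""

    def __init__(self, temperature=1.0, min_margin=1.0):
        self.min_margin_v = min_margin
        # Scaled value kept on the instance, as in the reviewed line.
        self.temperature = temperature * self.min_margin_v

    def get_config(self):
        # Undo the scaling so the config holds the constructor argument.
        return {
            "temperature": self.temperature / self.min_margin_v,
            "min_margin": self.min_margin_v,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


loss = SoftHKRConfigSketch(temperature=2.0, min_margin=0.5)
restored = SoftHKRConfigSketch.from_config(loss.get_config())
```

Without the division, reloading would apply the margin scaling a second time, so the restored object would not match the original.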
remove useless docstring
Co-authored-by: cofri <cofri@users.noreply.github.com>
Introduces the multiclass loss `MulticlassSoftHKR`, defined in the paper "On the explainable properties of 1-Lipschitz Neural Networks: An Optimal Transport Perspective" by Serrurier et al. (NeurIPS'23), which combines an optimal transport loss with a softmax temperature coefficient (useful when the number of classes is high, as in the ImageNet dataset).
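To illustrate the role of a softmax temperature in general terms, here is a small NumPy sketch. It is not the formula from the paper nor the code in this PR; the function `non_true_softmax_weights` and the example logits are assumptions made for illustration. It only shows that multiplying logits by a larger temperature concentrates the softmax weight on the largest logit among the classes other than the true one.

```python
import numpy as np


def non_true_softmax_weights(logits, true_class, temperature):
    """Softmax weights over the logits of all classes except `true_class`."""
    others = np.delete(np.asarray(logits, dtype=float), true_class)
    scaled = temperature * others
    scaled -= scaled.max()  # subtract the max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()


logits = [2.0, 1.0, 0.5, -1.0]  # class 0 is taken as the true class
low = non_true_softmax_weights(logits, true_class=0, temperature=1.0)
high = non_true_softmax_weights(logits, true_class=0, temperature=10.0)
```

With many classes, a low temperature spreads the weight over all wrong classes, while a high one focuses it on the strongest competitor.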