Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add observable class #93

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ fail_fast: true

repos:
- repo: https://github.com/psf/black
rev: 22.8.0
rev: 23.3.0
hooks:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies: [flake8-isort]
2 changes: 2 additions & 0 deletions znnl/distance_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""
from znnl.distance_metrics.angular_distance import AngularDistance
from znnl.distance_metrics.cosine_distance import CosineDistance
from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance
from znnl.distance_metrics.distance_metric import DistanceMetric
from znnl.distance_metrics.hyper_sphere_distance import HyperSphere
from znnl.distance_metrics.l_p_norm import LPNorm
Expand All @@ -40,4 +41,5 @@
OrderNDifference.__name__,
MahalanobisDistance.__name__,
HyperSphere.__name__,
CrossEntropyDistance.__name__,
]
11 changes: 11 additions & 0 deletions znnl/distance_metrics/angular_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def __init__(self, points: int = None):
else:
raise ValueError("Invalid points input.")

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "angular_distance"

def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
"""
Call the distance metric.
Expand Down
11 changes: 11 additions & 0 deletions znnl/distance_metrics/cosine_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,14 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
)

return 1 - abs(np.divide(numerator, denominator))

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "cosine_distance"
61 changes: 61 additions & 0 deletions znnl/distance_metrics/cross_entropy_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""
import optax

from znnl.distance_metrics.distance_metric import DistanceMetric


class CrossEntropyDistance(DistanceMetric):
"""
Class for the cross entropy distance
"""

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "cross_entropy_distance"

def __call__(self, prediction, target):
"""

Parameters
----------
prediction (batch_size, n_classes)
target

Returns
-------
Softmax cross entropy of the batch.

"""
return optax.softmax_cross_entropy(logits=prediction, labels=target)
27 changes: 26 additions & 1 deletion znnl/distance_metrics/distance_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,37 @@
"""
import jax.numpy as np

from znnl.observables.observable import Observable

class DistanceMetric:

class DistanceMetric(Observable):
"""
Parent class for a ZnRND distance metric.
"""

def __name__(self) -> str:
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "distance_metric"

def __signature__(self) -> tuple:
"""
Signature of the class.

Returns
-------
signature : tuple
The signature of the class.
For the distance metric, it is (1,).
"""
return (1,)

def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
"""
Call the distance metric.
Expand Down
11 changes: 11 additions & 0 deletions znnl/distance_metrics/hyper_sphere_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,14 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
return LPNorm(order=self.order)(point_1, point_2) * CosineDistance()(
point_1, point_2
)

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return f"hyper_sphere_distance_{self.order}"
11 changes: 11 additions & 0 deletions znnl/distance_metrics/l_p_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,14 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
Array of distances for each point.
"""
return np.linalg.norm(point_1 - point_2, axis=1, ord=self.order)

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return f"lp_norm_{self.order}"
11 changes: 11 additions & 0 deletions znnl/distance_metrics/mahalanobis_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,14 @@ def __call__(self, point_1: np.array, point_2: np.array, **kwargs) -> np.array:
distances.append(distance)

return distances

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "mahalanobis_distance"
11 changes: 11 additions & 0 deletions znnl/distance_metrics/order_n_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,14 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
return np.sum(np.power(diff, self.order), axis=1)
else:
raise ValueError(f"Invalid reduction operation: {self.reduce_operation}")

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "order_{self.order}_difference_{self.reduce_operation}"
4 changes: 2 additions & 2 deletions znnl/loss_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@
from znnl.loss_functions.cosine_distance import CosineDistanceLoss
from znnl.loss_functions.cross_entropy_loss import CrossEntropyLoss
from znnl.loss_functions.l_p_norm import LPNormLoss
from znnl.loss_functions.loss import Loss
from znnl.loss_functions.mahalanobis import MahalanobisLoss
from znnl.loss_functions.mean_power_error import MeanPowerLoss
from znnl.loss_functions.simple_loss import SimpleLoss

__all__ = [
AngleDistanceLoss.__name__,
CosineDistanceLoss.__name__,
LPNormLoss.__name__,
MahalanobisLoss.__name__,
MeanPowerLoss.__name__,
SimpleLoss.__name__,
Loss.__name__,
CrossEntropyLoss.__name__,
]
15 changes: 13 additions & 2 deletions znnl/loss_functions/absolute_angle_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
-------
"""
from znnl.distance_metrics.angular_distance import AngularDistance
from znnl.loss_functions.simple_loss import SimpleLoss
from znnl.loss_functions.loss import Loss


class AngleDistanceLoss(SimpleLoss):
class AngleDistanceLoss(Loss):
"""
Class for the mean power loss
"""
Expand All @@ -39,3 +39,14 @@ def __init__(self):
"""
super(AngleDistanceLoss, self).__init__()
self.metric = AngularDistance()

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "angle_distance_loss"
15 changes: 13 additions & 2 deletions znnl/loss_functions/cosine_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
-------
"""
from znnl.distance_metrics.cosine_distance import CosineDistance
from znnl.loss_functions.simple_loss import SimpleLoss
from znnl.loss_functions.loss import Loss


class CosineDistanceLoss(SimpleLoss):
class CosineDistanceLoss(Loss):
"""
Class for the mean power loss
"""
Expand All @@ -39,3 +39,14 @@ def __init__(self):
"""
super(CosineDistanceLoss, self).__init__()
self.metric = CosineDistance()

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "cosine_distance_loss"
39 changes: 14 additions & 25 deletions znnl/loss_functions/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,11 @@
Summary
-------
"""
import optax
from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance
from znnl.loss_functions.loss import Loss

from znnl.loss_functions.simple_loss import SimpleLoss


class CrossEntropyDistance:
"""
Class for the cross entropy distance
"""

def __call__(self, prediction, target):
"""

Parameters
----------
prediction (batch_size, n_classes)
target

Returns
-------
Softmax cross entropy of the batch.

"""
return optax.softmax_cross_entropy(logits=prediction, labels=target)


class CrossEntropyLoss(SimpleLoss):
class CrossEntropyLoss(Loss):
"""
Class for the cross entropy loss
"""
Expand All @@ -61,3 +39,14 @@ def __init__(self):
"""
super(CrossEntropyLoss, self).__init__()
self.metric = CrossEntropyDistance()

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "cross_entropy_loss"
13 changes: 12 additions & 1 deletion znnl/loss_functions/l_p_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
-------
"""
from znnl.distance_metrics.l_p_norm import LPNorm
from znnl.loss_functions.simple_loss import SimpleLoss
from znnl.loss_functions.loss import SimpleLoss


class LPNormLoss(SimpleLoss):
Expand All @@ -44,3 +44,14 @@ def __init__(self, order: float):
"""
super(LPNormLoss, self).__init__()
self.metric = LPNorm(order=order)

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return f"lp_norm_loss_{self.order}"
Loading