Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add observable class #93

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
update pre-commit and fix formatting
  • Loading branch information
SamTov committed May 26, 2023
commit f2861780fc63eecc59d4cf90204ceaf7b96f8a81
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ fail_fast: true

repos:
- repo: https://github.com/psf/black
rev: 22.8.0
rev: 23.3.0
hooks:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 6.0.0
hooks:
- id: flake8
additional_dependencies: [flake8-isort]
4 changes: 2 additions & 2 deletions znnl/distance_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
"""
from znnl.distance_metrics.angular_distance import AngularDistance
from znnl.distance_metrics.cosine_distance import CosineDistance
from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance
from znnl.distance_metrics.distance_metric import DistanceMetric
from znnl.distance_metrics.hyper_sphere_distance import HyperSphere
from znnl.distance_metrics.l_p_norm import LPNorm
from znnl.distance_metrics.mahalanobis_distance import MahalanobisDistance
from znnl.distance_metrics.order_n_difference import OrderNDifference
from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance

__all__ = [
DistanceMetric.__name__,
Expand All @@ -41,5 +41,5 @@
OrderNDifference.__name__,
MahalanobisDistance.__name__,
HyperSphere.__name__,
CrossEntropyDistance.__name__
CrossEntropyDistance.__name__,
]
4 changes: 2 additions & 2 deletions znnl/distance_metrics/angular_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, points: int = None):
self.normalization = points / np.pi
else:
raise ValueError("Invalid points input.")

def __name__(self):
"""
Name of the class.
Expand All @@ -59,7 +59,7 @@ def __name__(self):
name : str
The name of the class.
"""
return f"angular_distance"
return "angular_distance"

def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
"""
Expand Down
4 changes: 2 additions & 2 deletions znnl/distance_metrics/cosine_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
)

return 1 - abs(np.divide(numerator, denominator))

def __name__(self):
"""
Name of the class.
Expand All @@ -78,4 +78,4 @@ def __name__(self):
name : str
The name of the class.
"""
return f"cosine_distance"
return "cosine_distance"
11 changes: 11 additions & 0 deletions znnl/distance_metrics/cross_entropy_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ class CrossEntropyDistance(DistanceMetric):
Class for the cross entropy distance
"""

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "cross_entropy_distance"

def __call__(self, prediction, target):
"""

Expand Down
2 changes: 1 addition & 1 deletion znnl/distance_metrics/distance_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __name__(self) -> str:
The name of the class.
"""
return "distance_metric"

def __signature__(self) -> tuple:
"""
Signature of the class.
Expand Down
2 changes: 1 addition & 1 deletion znnl/distance_metrics/hyper_sphere_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
return LPNorm(order=self.order)(point_1, point_2) * CosineDistance()(
point_1, point_2
)

def __name__(self):
"""
Name of the class.
Expand Down
2 changes: 1 addition & 1 deletion znnl/distance_metrics/l_p_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
Array of distances for each point.
"""
return np.linalg.norm(point_1 - point_2, axis=1, ord=self.order)

def __name__(self):
"""
Name of the class.
Expand Down
4 changes: 2 additions & 2 deletions znnl/distance_metrics/mahalanobis_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __call__(self, point_1: np.array, point_2: np.array, **kwargs) -> np.array:
distances.append(distance)

return distances

def __name__(self):
"""
Name of the class.
Expand All @@ -76,4 +76,4 @@ def __name__(self):
name : str
The name of the class.
"""
return f"mahalanobis_distance"
return "mahalanobis_distance"
4 changes: 2 additions & 2 deletions znnl/distance_metrics/order_n_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs):
return np.sum(np.power(diff, self.order), axis=1)
else:
raise ValueError(f"Invalid reduction operation: {self.reduce_operation}")

def __name__(self):
"""
Name of the class.
Expand All @@ -89,4 +89,4 @@ def __name__(self):
name : str
The name of the class.
"""
return f"order_{self.order}_difference_{self.reduce_operation}"
return "order_{self.order}_difference_{self.reduce_operation}"
2 changes: 1 addition & 1 deletion znnl/loss_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from znnl.loss_functions.cosine_distance import CosineDistanceLoss
from znnl.loss_functions.cross_entropy_loss import CrossEntropyLoss
from znnl.loss_functions.l_p_norm import LPNormLoss
from znnl.loss_functions.loss import Loss
from znnl.loss_functions.mahalanobis import MahalanobisLoss
from znnl.loss_functions.mean_power_error import MeanPowerLoss
from znnl.loss_functions.loss import Loss

__all__ = [
AngleDistanceLoss.__name__,
Expand Down
18 changes: 9 additions & 9 deletions znnl/loss_functions/absolute_angle_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def __init__(self):
self.metric = AngularDistance()

def __name__(self):
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return f"angle_distance_loss"
"""
Name of the class.

Returns
-------
name : str
The name of the class.
"""
return "angle_distance_loss"
2 changes: 1 addition & 1 deletion znnl/loss_functions/cosine_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def __name__(self):
name : str
The name of the class.
"""
return f"cosine_distance_loss"
return "cosine_distance_loss"
6 changes: 2 additions & 4 deletions znnl/loss_functions/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
Summary
-------
"""
import optax

from znnl.loss_functions.loss import Loss
from znnl.distance_metrics.cross_entropy_distance import CrossEntropyDistance
from znnl.loss_functions.loss import Loss


class CrossEntropyLoss(Loss):
Expand All @@ -51,4 +49,4 @@ def __name__(self):
name : str
The name of the class.
"""
return f"cross_entropy_loss"
return "cross_entropy_loss"
2 changes: 1 addition & 1 deletion znnl/loss_functions/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __name__(self) -> str:
The name of the class.
"""
return "loss_parent"

def __signature__(self) -> tuple:
"""
Signature of the class.
Expand Down
2 changes: 1 addition & 1 deletion znnl/loss_functions/mahalanobis.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def __name__(self):
name : str
The name of the class.
"""
return f"mahalanobis_loss"
return "mahalanobis_loss"
2 changes: 1 addition & 1 deletion znnl/loss_functions/mean_power_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def __name__(self):
name : str
The name of the class.
"""
return f"mean_power_loss_{self.order}"
return "mean_power_loss_{self.order}"
4 changes: 1 addition & 3 deletions znnl/observables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,4 @@
"""
from znnl.observables.observable import Observable

__all__ = [
Observable.__name__
]
__all__ = [Observable.__name__]
6 changes: 3 additions & 3 deletions znnl/observables/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __name__(self) -> str:
The name of the observable.
"""
raise NotImplementedError("Implemented in child class.")

@classmethod
def __signature__(self, data_set: dict) -> tuple:
"""
Expand All @@ -63,7 +63,7 @@ def __signature__(self, data_set: dict) -> tuple:
The signature of the observable.
"""
raise NotImplementedError("Implemented in child class.")

@classmethod
def __call__(self, data_set: dict) -> Union[str, np.ndarray, float]:
"""
Expand All @@ -78,6 +78,6 @@ def __call__(self, data_set: dict) -> Union[str, np.ndarray, float]:
-------
value : Union[str, np.ndarray, float]
The value of the observable.

"""
raise NotImplementedError("Implemented in child class.")
6 changes: 3 additions & 3 deletions znnl/training_strategies/partitioned_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def __init__(
Random seed for the RNG. Uses a random int if not specified.
recursive_mode : RecursiveMode
Defining the recursive mode that can be used in training.
If the recursive mode is used, the training will be performed until a
condition is fulfilled.
If the recursive mode is used, the training will be performed
until a condition is fulfilled.
The loss value at which point you consider the model trained.
disable_loading_bar : bool
Disable the output visualization of the loading bar.
Expand Down Expand Up @@ -244,7 +244,7 @@ def train_model(
Number of epochs to train over.
Each epoch defines a training phase.
train_ds_selection : list
(default = [slice(-1, None, None), slice(None, None, None)])
default = [slice(-1, None, None), slice(None, None, None)]
The train is selected by a np.array of indices or slices.
Each slice or array defines a training phase.
batch_size : list (default = [1, 1])
Expand Down