Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zqs 920 loss function wrapper #33

Merged
merged 8 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/gradient_descent_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def gradient_descent_optimizer(
init_params: np.ndarray,
loss_function: Callable,
loss_function: Callable[[np.ndarray], float],
n_iters: int,
learning_rate: float = 0.1,
full_gradient_function: Optional[Callable] = None,
Expand Down
263 changes: 158 additions & 105 deletions docs/examples/orqviz_tutorial_cirq.ipynb

Large diffs are not rendered by default.

271 changes: 167 additions & 104 deletions docs/examples/orqviz_tutorial_orquestra.ipynb

Large diffs are not rendered by default.

128 changes: 64 additions & 64 deletions docs/examples/orqviz_tutorial_pennylane.ipynb

Large diffs are not rendered by default.

287 changes: 156 additions & 131 deletions docs/examples/orqviz_tutorial_qiskit.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion src/orqviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@
plot_utils,
plots,
scans,
utils,
)
41 changes: 41 additions & 0 deletions src/orqviz/loss_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import time
from typing import Callable, Optional

import numpy as np


def _calculate_new_average(
previous_average: Optional[float], count: int, new_value: float
) -> float:
if previous_average is None:
return new_value
else:
return (count * previous_average + new_value) / (count + 1)


class LossFunctionWrapper:
def __init__(self, loss_function: Callable, *args, **kwargs):
def wrapped_loss_function(params: np.ndarray) -> float:
return loss_function(params, *args, **kwargs)

self.loss_function = wrapped_loss_function
self.call_count = 0
self.average_call_time: Optional[float] = None
self.min_value: Optional[float] = None

def __call__(self, params: np.ndarray) -> float:
start_time = time.perf_counter()
value = self.loss_function(params)
total_time = time.perf_counter() - start_time
self.average_call_time = _calculate_new_average(
self.average_call_time, self.call_count, total_time
)
self.call_count += 1
if self.min_value is None or value < self.min_value:
self.min_value = value
return value

def reset(self) -> None:
self.call_count = 0
self.average_call_time = None
self.min_value = None
1 change: 0 additions & 1 deletion src/orqviz/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pickle
import warnings
from typing import Union

Expand Down
1 change: 1 addition & 0 deletions tests/orqviz/io_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import numpy as np
import pytest

from orqviz.elastic_band.data_structures import Chain
from orqviz.hessians import get_Hessian
Expand Down
94 changes: 94 additions & 0 deletions tests/orqviz/loss_function_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np

from orqviz.loss_function import LossFunctionWrapper


def mock_function(params, b, c):
return np.sum(params * b + c)


class TestLossFunctionWrapper:
def test_if_wraps_args(self):
b = 10
c = 2
params = np.array([1, 2, 3, 4])
target_value = mock_function(params, b, c)

loss_function = LossFunctionWrapper(mock_function, b, c)

result = loss_function(params=params)
assert result == target_value

def test_if_wraps_kwargs(self):
b = 10
c = 2
params = np.array([1, 2, 3, 4])
target_value = mock_function(params, b, c)

loss_function = LossFunctionWrapper(mock_function, b=b, c=c)

result = loss_function(params=params)
assert result == target_value

def test_if_wraps_both_args_and_kwargs(self):
b = 10
c = 2
params = np.array([1, 2, 3, 4])
target_value = mock_function(params, b, c)

loss_function = LossFunctionWrapper(mock_function, b, c=c)

result = loss_function(params=params)

assert result == target_value

def test_tracks_average_call_time(self):
b = 10
c = 2
call_count = 10

loss_function = LossFunctionWrapper(mock_function, b, c=c)
for i in range(call_count):
_ = loss_function(params=np.random.random(i))

assert loss_function.average_call_time is not None

def test_tracks_call_count(self):
b = 10
c = 2
call_count = 10

loss_function = LossFunctionWrapper(mock_function, b, c=c)
for i in range(call_count):
_ = loss_function(params=np.random.random(i))

assert loss_function.call_count == call_count

def test_tracks_min_value(self):
b = 10
c = 2
call_count = 10

loss_function = LossFunctionWrapper(mock_function, b, c=c)
min_value = np.inf
for i in range(call_count):
value = loss_function(params=np.random.random(i))
if value < min_value:
min_value = value

assert loss_function.min_value == min_value

def test_reset(self):
b = 10
c = 2
call_count = 10

loss_function = LossFunctionWrapper(mock_function, b, c=c)
for i in range(call_count):
_ = loss_function(params=np.random.random(i))

loss_function.reset()

assert loss_function.average_call_time is None
assert loss_function.call_count == 0
assert loss_function.min_value is None
1 change: 0 additions & 1 deletion tests/orqviz/scans/scans_plots_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import matplotlib.pyplot as plt
import numpy as np
import pytest

from orqviz.io import load_viz_object, save_viz_object
from orqviz.scans import perform_1D_scan, perform_2D_interpolation
Expand Down
5 changes: 0 additions & 5 deletions tests/orqviz/scans/scans_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@
perform_2D_interpolation,
)
from orqviz.scans.data_structures import Scan1DResult, Scan2DResult
from orqviz.scans.plots import (
plot_1D_scan_result,
plot_2D_interpolation_result,
plot_2D_scan_result,
)
from orqviz.scans.scans_2D import perform_2D_scan


Expand Down