Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LazyLoader for delayed imports of slow modules #4653

Merged
merged 6 commits into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions cirq-core/cirq/_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,55 @@ def wrap_func(module: ModuleType) -> Optional[ModuleType]:
for module in execute_list:
if module.__loader__ is not None and hasattr(module.__loader__, 'exec_module'):
cast(Loader, module.__loader__).exec_module(module) # Calls back into wrap_func


class LazyLoader(ModuleType):
"""Lazily import a module, mainly to avoid pulling in large dependencies.

This class is a modified version of a similar class in TensorFlow.

To use, instead of importing the module normally
```
import heavy_module
```
define the module
```
heavy_module = LazyLoader("heavy_module", globals(), "mypackage.heavy_module")
```
"""

def __init__(self, local_name, parent_module_globals, name):
"""Create the LazyLoader module.

Args:
local_name: The local name that the module will be refered to as.
parent_module_globals: The globals of the module where this should be imported.
Typically this will be globals().
name: The full qualified name of the module.
"""
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._module = None
super().__init__(name)

def _load(self):
"""Load the module and insert it into the parent's globals."""
# Import the target module and insert it into the parent's namespace
if self._module:
return self._module
self._module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = self._module

# Update this object's dict so that if someone keeps a reference to the LazyLoader,
# lookups are efficient (__getattr__ is only called on lookups that fail).
self.__dict__.update(self._module.__dict__)

return self._module

def __getattr__(self, item):
module = self._load()
return getattr(module, item)

def __dir__(self):
module = self._load()
return dir(module)
30 changes: 30 additions & 0 deletions cirq-core/cirq/_import_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cirq import _import


def test_lazy_loader():
linalg = _import.LazyLoader("linalg", globals(), "scipy.linalg")
linalg.fun = 1
assert linalg._module is None
assert "linalg" not in linalg.__dict__

linalg.det([[1]])

assert linalg._module is not None
assert globals()["linalg"] == linalg._module
assert "fun" in linalg.__dict__
assert "LinAlgError" in dir(linalg)
assert linalg.fun == 1
10 changes: 5 additions & 5 deletions cirq-core/cirq/experiments/t1_decay_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
import numpy as np


from cirq import circuits, ops, study, value
from cirq import circuits, ops, study, value, _import
from cirq._compat import proper_repr

if TYPE_CHECKING:
import cirq

# We initialize optimize lazily, otherwise it slows global import speed.
optimize = _import.LazyLoader("optimize", globals(), "scipy.optimize")


# TODO(#3388) Add documentation for Raises.
# pylint: disable=missing-raises-doc
Expand Down Expand Up @@ -130,10 +133,7 @@ def exp_decay(x, t1):

# Fit to exponential decay to find the t1 constant
try:
# Import scipy.optimize here to avoid costly module level import.
import scipy.optimize

popt, _ = scipy.optimize.curve_fit(exp_decay, xs, probs, p0=[t1_guess])
popt, _ = optimize.curve_fit(exp_decay, xs, probs, p0=[t1_guess])
t1 = popt[0]
return t1
except RuntimeError:
Expand Down
18 changes: 8 additions & 10 deletions cirq-core/cirq/experiments/xeb_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np
import pandas as pd
import sympy
from cirq import ops, protocols
from cirq import ops, protocols, _import
from cirq.circuits import Circuit
from cirq.experiments.xeb_simulation import simulate_2q_xeb_circuits

Expand All @@ -36,6 +36,10 @@
import multiprocessing
import scipy.optimize

# We initialize these lazily, otherwise they slow global import speed.
optimize = _import.LazyLoader("optimize", globals(), "scipy.optimize")
stats = _import.LazyLoader("stats", globals(), "scipy.stats")

THETA_SYMBOL, ZETA_SYMBOL, CHI_SYMBOL, GAMMA_SYMBOL, PHI_SYMBOL = sympy.symbols(
'theta zeta chi gamma phi'
)
Expand Down Expand Up @@ -410,10 +414,7 @@ def _mean_infidelity(angles):
print(f"Loss: {loss:7.3g}", flush=True)
return loss

# Import scipy.optimize here to avoid costly top level moule import.
import scipy.optimize

optimization_result = scipy.optimize.minimize(
optimization_result = optimize.minimize(
_mean_infidelity,
x0=x0,
options={
Expand Down Expand Up @@ -574,15 +575,12 @@ def _fit_exponential_decay(
cycle_depths_pos = cycle_depths[positives]
log_fidelities = np.log(fidelities[positives])

# We import here to avoid costly module level load time dependency on scipy.stats.
import scipy.stats

slope, intercept, _, _, _ = scipy.stats.linregress(cycle_depths_pos, log_fidelities)
slope, intercept, _, _, _ = stats.linregress(cycle_depths_pos, log_fidelities)
layer_fid_0 = np.clip(np.exp(slope), 0, 1)
a_0 = np.clip(np.exp(intercept), 0, 1)

try:
(a, layer_fid), pcov = scipy.optimize.curve_fit(
(a, layer_fid), pcov = optimize.curve_fit(
exponential_decay,
cycle_depths,
fidelities,
Expand Down
19 changes: 10 additions & 9 deletions cirq-core/cirq/qis/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
"""Measures on and between quantum states and operations."""


from typing import Optional, TYPE_CHECKING, Tuple

import numpy as np
import scipy

from cirq import protocols, value
from cirq import protocols, value, _import
from cirq.qis.states import (
QuantumState,
infer_qid_shape,
Expand All @@ -27,13 +27,18 @@
validate_normalized_state_vector,
)

# We initialize these lazily, otherwise they slow global import speed.
stats = _import.LazyLoader("stats", globals(), "scipy.stats")
linalg = _import.LazyLoader("linalg", globals(), "scipy.linalg")


if TYPE_CHECKING:
import cirq


def _sqrt_positive_semidefinite_matrix(mat: np.ndarray) -> np.ndarray:
"""Square root of a positive semidefinite matrix."""
eigs, vecs = scipy.linalg.eigh(mat)
eigs, vecs = linalg.eigh(mat)
return vecs @ (np.sqrt(np.abs(eigs)) * vecs).T.conj()


Expand Down Expand Up @@ -237,7 +242,7 @@ def _fidelity_state_vectors_or_density_matrices(state1: np.ndarray, state2: np.n
elif state1.ndim == 2 and state2.ndim == 2:
# Both density matrices
state1_sqrt = _sqrt_positive_semidefinite_matrix(state1)
eigs = scipy.linalg.eigvalsh(state1_sqrt @ state2 @ state1_sqrt)
eigs = linalg.eigvalsh(state1_sqrt @ state2 @ state1_sqrt)
trace = np.sum(np.sqrt(np.abs(eigs)))
return trace ** 2
raise ValueError(
Expand Down Expand Up @@ -277,11 +282,7 @@ def von_neumann_entropy(
qid_shape = (state.shape[0],)
validate_density_matrix(state, qid_shape=qid_shape, dtype=state.dtype, atol=atol)
eigenvalues = np.linalg.eigvalsh(state)

# We import here to avoid a costly module level load time dependency on scipy.stats.
import scipy.stats

return scipy.stats.entropy(np.abs(eigenvalues), base=2)
return stats.entropy(np.abs(eigenvalues), base=2)
if validate:
_ = quantum_state(state, qid_shape=qid_shape, copy=False, validate=True, atol=atol)
return 0.0
Expand Down