Skip to content

Commit

Permalink
Merge pull request #410 from JaxGaussianProcesses/henry/fix_config_im…
Browse files Browse the repository at this point in the history
…port
  • Loading branch information
henrymoss authored Nov 8, 2023
2 parents 5758238 + 7dc9bc0 commit f40091e
Show file tree
Hide file tree
Showing 30 changed files with 31 additions and 31 deletions.
2 changes: 1 addition & 1 deletion docs/examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/bayesian_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/collapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/deep_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/intro_to_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/likelihoods_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# +
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
4 changes: 2 additions & 2 deletions docs/examples/oceanmodelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
#
# Surface drifters are measurement devices that measure the dynamics and circulation patterns of the world's oceans. Studying and predicting ocean currents are important to climate research, for example, forecasting and predicting oil spills, oceanographic surveying of eddies and upwelling, or providing information on the distribution of biomass in ecosystems. We will be using the [Gulf Drifters Open dataset](https://zenodo.org/record/4421585), which contains all publicly available surface drifter trajectories from the Gulf of Mexico spanning 28 years.
# %%
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
from dataclasses import dataclass

from jax import hessian
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax.config import config
from jax import config
from jaxtyping import install_import_hook

with install_import_hook("gpjax", "beartype.beartype"):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/regression_mo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#
# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/uncollapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/yacht.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# %%
# Enable Float64 for more stable matrix inversions.
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_citations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
except ImportError:
ValidationErrors = ValueError

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from dataclasses import dataclass

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gaussian_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# ==============================================================================


from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
Type,
)

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import typing as tp

import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import numpy as np
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from cola.ops import Dense
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
field,
)

from jax.config import config
from jax import config
import jax.numpy as jnp
from jaxtyping import (
Array,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_non_euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# # limitations under the License.

from cola.ops import I_like
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import networkx as nx
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_nonstationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from cola.ops import LinearOperator
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_stationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from cola.ops import LinearOperator
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
List,
)

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down
2 changes: 1 addition & 1 deletion tests/test_objectives.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import pytest
Expand Down
2 changes: 1 addition & 1 deletion tests/test_variational_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Tuple,
)

from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
Expand Down

0 comments on commit f40091e

Please sign in to comment.