Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 119 additions & 39 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _short_str(arr):
Return a short string representation of an array, suitable for use in
error messages.
"""
arr = arr.value if isinstance(arr, Quantity) else arr
arr = np.asanyarray(arr)
old_printoptions = jnp.get_printoptions()
jnp.set_printoptions(edgeitems=2, threshold=5)
Expand All @@ -112,7 +113,7 @@ def _short_str(arr):
return arr_string


def get_unit_for_display(d):
def get_dim_for_display(d):
"""
Return a string representation of an appropriate unscaled unit or ``'1'``
for a dimensionless array.
Expand Down Expand Up @@ -181,6 +182,13 @@ def get_unit_for_display(d):
"cd": 6,
}

# Length (meter)
# Mass (kilogram)
# Time (second)
# Current (ampere)
# Temperature (Kelvin)
# Amount of substance (mole)
# Luminous intensity (candela)
_ilabel = ["m", "kg", "s", "A", "K", "mol", "cd"]

# The same labels with the names used for constructing them in Python code
Expand Down Expand Up @@ -453,6 +461,8 @@ def get_or_create_dimension(*args, **kwds):

'''The dimensionless unit, used for quantities without a unit.'''
DIMENSIONLESS = Dimension((0, 0, 0, 0, 0, 0, 0))

'''The dictionary of all existing Dimension objects.'''
_dimensions = {(0, 0, 0, 0, 0, 0, 0): DIMENSIONLESS}


Expand Down Expand Up @@ -492,16 +502,16 @@ def __str__(self):
if len(self.dims) == 0:
pass
elif len(self.dims) == 1:
s += f" (unit is {get_unit_for_display(self.dims[0])}"
s += f" (unit is {get_dim_for_display(self.dims[0])}"
elif len(self.dims) == 2:
d1, d2 = self.dims
s += (
f" (units are {get_unit_for_display(d1)} and {get_unit_for_display(d2)}"
f" (units are {get_dim_for_display(d1)} and {get_dim_for_display(d2)}"
)
else:
s += (
" (units are"
f" {' '.join([f'({get_unit_for_display(d)})' for d in self.dims])}"
f" {' '.join([f'({get_dim_for_display(d)})' for d in self.dims])}"
)
if len(self.dims):
s += ")."
Expand All @@ -510,7 +520,7 @@ def __str__(self):

def get_dim(obj) -> Dimension:
"""
Return the unit of any object that has them.
Return the dimension of any object that has them.

Slightly more general than `Array.dimensions` because it will
return `DIMENSIONLESS` if the object is of number type but not a `Array`
Expand Down Expand Up @@ -741,9 +751,9 @@ def in_best_unit(x, precision=None):
return x.repr_in_unit(u, precision=precision)


def array_with_unit(
def array_with_dim(
floatval,
unit: Dimension,
dim: Dimension,
dtype: jax.typing.DTypeLike = None
) -> 'Quantity':
"""
Expand All @@ -757,8 +767,8 @@ def array_with_unit(
----------
floatval : `float`
The floating point value of the array.
unit: Dimension
The unit dimensions of the array.
dim: Dimension
The dim dimensions of the array.
dtype: `dtype`, optional
The data type of the array.

Expand All @@ -770,10 +780,10 @@ def array_with_unit(
Examples
--------
>>> from brainunit import *
>>> array_with_unit(0.001, volt.dim)
>>> array_with_dim(0.001, volt.dim)
1. * mvolt
"""
return Quantity(floatval, dim=get_or_create_dimension(unit._dims), dtype=dtype)
return Quantity(floatval, dim=get_or_create_dimension(dim._dims), dtype=dtype)


def is_unitless(obj) -> bool:
Expand Down Expand Up @@ -1054,6 +1064,34 @@ def dim(self, *args):
raise NotImplementedError("Cannot set the dimension of a Quantity object directly,"
"Please create a new Quantity object with the value you want.")

@property
def unit(self) -> 'Unit':
return Unit(1., self.dim, register=False)

@unit.setter
def unit(self, *args):
# Do not support setting the unit directly
raise NotImplementedError("Cannot set the unit of a Quantity object directly,"
"Please create a new Quantity object with the unit you want.")

def to_value(self, unit: 'Unit') -> jax.Array | numbers.Number:
"""
Convert the value of the array to a new unit.

Examples::

>>> a = jax.numpy.array([1, 2, 3]) * mV
>>> a.to_value(volt)
array([0.001, 0.002, 0.003])

Args:
unit: The new unit to convert the value of the array to.

Returns:
The value of the array in the new unit.
"""
return self.value / unit.value

@staticmethod
def with_units(value, *args, **keywords):
"""
Expand Down Expand Up @@ -1506,9 +1544,7 @@ def __radd__(self, oc):

def __iadd__(self, oc):
# a += b
r = self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True)
self.update_value(r.value)
return self
return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True)

def __sub__(self, oc):
return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-")
Expand All @@ -1518,9 +1554,7 @@ def __rsub__(self, oc):

def __isub__(self, oc):
# a -= b
r = self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True)
self.update_value(r.value)
return self
return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True)

def __mul__(self, oc):
r = self._binary_operation(oc, operator.mul, operator.mul)
Expand Down Expand Up @@ -1731,7 +1765,7 @@ def __round__(self, ndigits: int = None) -> 'Quantity':
return Quantity(self.value.__round__(ndigits), dim=self.dim)

def __reduce__(self):
return array_with_unit, (self.value, self.dim, None)
return array_with_dim, (self.value, self.dim, None)

# ----------------------- #
# NumPy methods #
Expand Down Expand Up @@ -1963,10 +1997,19 @@ def take(
) -> 'Quantity':
"""Return an array formed from the elements of a at the given indices."""
if isinstance(fill_value, Quantity):
fail_for_dimension_mismatch(self, fill_value, "take")
fill_value = fill_value.value
return Quantity(jnp.take(self.value, indices=indices, axis=axis, mode=mode,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
fill_value=fill_value), dim=self.dim)
elif fill_value is not None:
if not self.is_unitless:
raise TypeError(f"fill_value must be a Quantity when the unit {self.unit}. But got {fill_value}")
return Quantity(
jnp.take(self.value,
indices=indices, axis=axis, mode=mode,
unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted,
fill_value=fill_value),
dim=self.dim
)

def tolist(self):
"""Return the array as an ``a.ndim``-levels deep nested list of Python scalars.
Expand Down Expand Up @@ -2226,9 +2269,11 @@ def view(self, *args, dtype=None) -> 'Quantity':
# NumPy support
# ------------------

def to_numpy(self,
dtype: Optional[jax.typing.DTypeLike] = None,
unit: Optional['Unit'] = None) -> np.ndarray:
def to_numpy(
self,
unit: Optional['Unit'] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
) -> np.ndarray:
"""
Remove the unit and convert to ``numpy.ndarray``.

Expand All @@ -2240,14 +2285,19 @@ def to_numpy(self,
The numpy.ndarray.
"""
if unit is None:
assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to "
f"NumPy arrays when 'unit' is not provided. But got {self}")
return np.asarray(self.value, dtype=dtype)
else:
fail_for_dimension_mismatch(self, unit, "to_numpy")
assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}"
return np.asarray(self / unit, dtype=dtype)

def to_jax(self,
dtype: Optional[jax.typing.DTypeLike] = None,
unit: Optional['Unit'] = None) -> jax.Array:
def to_jax(
self,
unit: Optional['Unit'] = None,
dtype: Optional[jax.typing.DTypeLike] = None,
) -> jax.Array:
"""
Remove the unit and convert to ``jax.Array``.

Expand All @@ -2259,20 +2309,50 @@ def to_jax(self,
The jax.Array.
"""
if unit is None:
assert self.dim == DIMENSIONLESS, (f"only dimensionless quantities can be converted to "
f"JAX arrays when 'unit' is not provided. But got {self}")
return jnp.asarray(self.value, dtype=dtype)
else:
fail_for_dimension_mismatch(self, unit, "to_jax")
assert isinstance(unit, Unit), f"unit must be a Unit object, but got {type(unit)}"
return jnp.asarray(self / unit, dtype=dtype)

def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray:
"""Support ``numpy.array()`` and ``numpy.asarray()`` functions."""
return np.asarray(self.value, dtype=dtype)
if self.dim == DIMENSIONLESS:
return np.asarray(self.value, dtype=dtype)
else:
raise TypeError(
f"only dimensionless quantities can be "
f"converted to NumPy arrays. But got {self}"
)

def __float__(self):
return self.value.__float__()
if self.dim == DIMENSIONLESS and self.ndim == 0:
return float(self.value)
else:
raise TypeError(
"only dimensionless scalar quantities can be "
f"converted to Python scalars. But got {self}"
)

def __int__(self):
if self.dim == DIMENSIONLESS and self.ndim == 0:
return int(self.value)
else:
raise TypeError(
"only dimensionless scalar quantities can be "
f"converted to Python scalars. But got {self}"
)

def __index__(self):
return operator.index(self.value)
if self.dim == DIMENSIONLESS:
return operator.index(self.value)
else:
raise TypeError(
"only dimensionless quantities can be "
f"converted to a Python index. But got {self}"
)

# ----------------------
# PyTorch compatibility
Expand Down Expand Up @@ -2518,6 +2598,7 @@ def __init__(
dispname: str = None,
iscompound: bool = None,
dtype: jax.typing.DTypeLike = None,
register: bool = True,
):
if dim is None:
dim = DIMENSIONLESS
Expand All @@ -2543,7 +2624,7 @@ def __init__(

super().__init__(value, dtype=dtype, dim=dim)

if _auto_register_unit:
if _auto_register_unit and register:
register_new_unit(self)

@staticmethod
Expand Down Expand Up @@ -2783,10 +2864,11 @@ def add(self, u: Unit):
if isinstance(u.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
self.units_for_dimensions[u.dim][1.] = u
else:
self.units_for_dimensions[u.dim][float(u)] = u
self.units_for_dimensions[u.dim][float(u.value)] = u

def __getitem__(self, x):
"""Returns the best unit for array x
"""
Returns the best unit for array x

The algorithm is to consider the value:

Expand Down Expand Up @@ -3005,9 +3087,7 @@ def new_f(*args, **kwds):
v = Quantity(v)
except TypeError:
if have_same_unit(au[n], 1):
raise TypeError(
f"Argument {n} is not a unitless value/array."
)
raise TypeError(f"Argument {n} is not a unitless value/array.")
else:
raise TypeError(
f"Argument '{n}' is not a array, "
Expand Down Expand Up @@ -3053,9 +3133,9 @@ def new_f(*args, **kwds):
f"the argument '{k}' to have the same "
f"units as argument '{au[k]}', but "
f"argument '{k}' has "
f"unit {get_unit_for_display(d1)}, "
f"unit {get_dim_for_display(d1)}, "
f"while argument '{au[k]}' "
f"has unit {get_unit_for_display(d2)}."
f"has unit {get_dim_for_display(d2)}."
)
raise DimensionMismatchError(error_message)
elif not have_same_unit(newkeyset[k], au[k]):
Expand Down Expand Up @@ -3087,7 +3167,7 @@ def new_f(*args, **kwds):
)
raise TypeError(error_message)
elif not have_same_unit(result, expected_result):
unit = get_unit_for_display(expected_result)
unit = get_dim_for_display(expected_result)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
Expand Down
28 changes: 28 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import unittest

import brainunit as bu


class TestQuantity(unittest.TestCase):
def test_dim(self):
a = [1, 2.] * bu.ms

with self.assertRaises(NotImplementedError):
a.dim = bu.mV.dim


Loading