Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 23 additions & 30 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class Dimension:
# ---- INITIALISATION ---- #

def __init__(self, dims):
self._dims = dims
self._dims: np.ndarray = np.asarray(dims)

# ---- METHODS ---- #
def get_dimension(self, d):
Expand Down Expand Up @@ -318,7 +318,8 @@ def is_dimensionless(self):
Normally, instead one should check dimension for being identical to
`DIMENSIONLESS`.
"""
return all([x == 0 for x in self._dims])
return np.allclose(self._dims, 0)
# return all([x == 0 for x in self._dims])

@property
def dim(self):
Expand Down Expand Up @@ -377,25 +378,24 @@ def __str__(self):
# Note that none of the dimension arithmetic objects do sanity checking
# on their inputs, although most will throw an exception if you pass the
# wrong sort of input
def __mul__(self, value):
return get_or_create_dimension([x + y for x, y in zip(self._dims, value._dims)])
def __mul__(self, value: 'Dimension'):
assert isinstance(value, Dimension), "Can only divide by a Dimension object"
return get_or_create_dimension(self._dims + value._dims)

def __div__(self, value):
return get_or_create_dimension([x - y for x, y in zip(self._dims, value._dims)])
def __div__(self, value: 'Dimension'):
assert isinstance(value, Dimension), "Can only divide by a Dimension object"
return get_or_create_dimension(self._dims - value._dims)

def __truediv__(self, value):
def __truediv__(self, value: 'Dimension'):
return self.__div__(value)

def __pow__(self, value: numbers.Number | jax.Array):
if value is DIMENSIONLESS:
return self
if isinstance(value, jax.core.Tracer):
value = jnp.array(value) # TODO: check jit
else:
value = np.array(value) # TODO: check jit
def __pow__(self, value: numbers.Number | np.ndarray):
if _is_tracer(value):
raise TypeError(f"Cannot use a tracer {value} as an exponent, please use a constant.")
value = np.array(value)
if value.size > 1:
raise TypeError("Too many exponents")
return get_or_create_dimension([x * value for x in self._dims])
return get_or_create_dimension(self._dims * value)

def __imul__(self, value):
raise NotImplementedError("Dimension object is immutable")
Expand Down Expand Up @@ -436,7 +436,7 @@ def __deepcopy__(self, memodict):
return self

def __hash__(self):
return hash(self._dims)
return hash(self._dims.tobytes())


@set_module_as('brainunit')
Expand Down Expand Up @@ -483,6 +483,7 @@ def get_or_create_dimension(*args, **kwds) -> Dimension:
e.g. length, metre, and m all refer to the same thing here.
"""
if len(args):
assert len(args) == 1, "Only one argument allowed"
# initialisation by list
dims = args[0]
try:
Expand All @@ -492,26 +493,18 @@ def get_or_create_dimension(*args, **kwds) -> Dimension:
raise TypeError("Need a sequence of exactly 7 items")
else:
# initialisation by keywords
dims = [0, 0, 0, 0, 0, 0, 0]
dims = np.asarray([0, 0, 0, 0, 0, 0, 0])
for k in kwds:
# _dim2index stores the index of the dimension with name 'k'
dims[_dim2index[k]] = kwds[k]

dims = tuple(dims)
dims = np.asarray(dims)
new_dim = Dimension(dims)
return new_dim

# # check whether this Dimension object has already been created
# if dims in _dimensions:
# return _dimensions[dims]
# else:
# new_dim = Dimension(dims)
# _dimensions[dims] = new_dim
# return new_dim


'''The dimensionless unit, used for quantities without a unit.'''
DIMENSIONLESS = Dimension((0, 0, 0, 0, 0, 0, 0))
DIMENSIONLESS = Dimension(np.asarray([0, 0, 0, 0, 0, 0, 0]))


class DimensionMismatchError(Exception):
Expand Down Expand Up @@ -1557,7 +1550,7 @@ def __mul__(self, other) -> 'Unit' | Quantity:
return Unit(dim, scale=scale, base=self.base, name=name, dispname=dispname, iscompound=iscompound)

elif isinstance(other, Quantity):
return Quantity(other._mantissa, unit=(self * other.unit))
return Quantity(other.mantissa, unit=(self * other.unit))

elif isinstance(other, Dimension):
raise TypeError(f"unit {self} cannot multiply by a Dimension {other}.")
Expand All @@ -1570,7 +1563,7 @@ def __rmul__(self, other) -> 'Unit' | Quantity:
if isinstance(other, Unit):
return other.__mul__(self)
elif isinstance(other, Quantity):
return Quantity(other._mantissa, unit=(other.unit * self))
return Quantity(other.mantissa, unit=(other.unit * self))
else:
return Quantity(other, unit=self)

Expand Down Expand Up @@ -1630,7 +1623,7 @@ def __rdiv__(self, other) -> 'Unit' | Quantity:
return other.__div__(self)

elif isinstance(other, Quantity):
return Quantity(other._mantissa, unit=(other.unit / self))
return Quantity(other.mantissa, unit=(other.unit / self))

else:
return Quantity(other, unit=(1 / self))
Expand Down
18 changes: 9 additions & 9 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,15 @@ def test_display(self):
assert_equal(display_in_unit(3.0 * bu.kmeter / 130.51 * bu.meter * bu.cm ** -1), '0.02298675 * 10.0^5 * meter')
assert_equal(display_in_unit(1. * bu.joule / bu.kelvin), '1. * joule / kelvin')

def test_display2(self):

@jax.jit
def f(s):
a = bu.ms ** s
print(a)
return bu.Quantity(1., unit=a)

f(2)
# def test_display2(self):
#
# @jax.jit
# def f(s):
# a = bu.ms ** s
# print(a)
# return bu.Quantity(1., unit=a)
#
# f(2)

def test_unary_operations(self):
q = Quantity(5, unit=mV)
Expand Down
Loading