Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c17b9fa
Update _compat_numpy.py
Routhleck Jun 10, 2024
72ccd90
Update _compat_numpy.py
Routhleck Jun 10, 2024
4ee28ed
Update
Routhleck Jun 11, 2024
fc2b978
Update _compat_numpy.py
Routhleck Jun 11, 2024
017eb6f
Fix
Routhleck Jun 11, 2024
22ef3b6
Update brainunit.math.rst
Routhleck Jun 11, 2024
cbbcfc9
Update _compat_numpy.py
Routhleck Jun 11, 2024
f6a0040
Update _unit_test.py
Routhleck Jun 11, 2024
fea30e2
Restruct
Routhleck Jun 11, 2024
229ad90
Merge branch 'main' of https://github.com/brainpy/brainunit into upda…
Routhleck Jun 11, 2024
337b365
Update
Routhleck Jun 11, 2024
6fc6add
Fix bugs
Routhleck Jun 11, 2024
08b90cd
Fix bugs in Python 3.9
Routhleck Jun 11, 2024
d702944
Update _compat_numpy_funcs_bit_operation.py
Routhleck Jun 11, 2024
b0154ab
Update _compat_numpy_funcs_bit_operation.py
Routhleck Jun 11, 2024
d0fcce6
Fix logic of `asarray`
Routhleck Jun 11, 2024
d3f38e9
update __str__
chaoming0625 Jun 11, 2024
c0f8171
update
chaoming0625 Jun 11, 2024
1b0d380
Update array creation funcs
Routhleck Jun 11, 2024
9cb9b3e
Merge branch 'main' into fix-asarray
Routhleck Jun 12, 2024
bb439c6
Update _compat_numpy_test.py
Routhleck Jun 12, 2024
ea4e9d5
Add magnitude conversion for `asarray`
Routhleck Jun 12, 2024
b967edc
Update _compat_numpy_array_creation.py
Routhleck Jun 12, 2024
512f734
Update _compat_numpy_test.py
Routhleck Jun 12, 2024
1647baa
Fix bugs
Routhleck Jun 12, 2024
f5bbe80
Merge branch 'fix-asarray' into fix
chaoming0625 Jun 12, 2024
d1257d0
fix tests
chaoming0625 Jun 12, 2024
5f0dcb2
fix tests
chaoming0625 Jun 12, 2024
deb6153
Merge remote-tracking branch 'origin' into fix
Routhleck Jun 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax.interpreters.partial_eval import DynamicJaxprTracer

from ._misc import get_dtype



__all__ = [
'Quantity',
'Unit',
Expand Down Expand Up @@ -535,7 +538,7 @@ def get_unit(obj) -> Dimension:
The physical dimensions of the `obj`.
"""
try:
return obj.unit
return obj.dim
except AttributeError:
# The following is not very pretty, but it will avoid the costly
# isinstance check for the common types
Expand Down Expand Up @@ -981,25 +984,27 @@ def __init__(
value = jnp.array(value, dtype=dtype)
except ValueError:
raise TypeError("All elements must be convertible to a jax array")
dtype = dtype or get_dtype(value)

# array value
if isinstance(value, Quantity):
dtype = dtype or get_dtype(value)
self._dim = value.dim
self._value = jnp.array(value.value, dtype=dtype)
return

elif isinstance(value, (np.ndarray, jax.Array)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jnp.number, numbers.Number)):
dtype = dtype or get_dtype(value)
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)):
value = value

else:
raise TypeError(f"Invalid type for value: {type(value)}")
value = value

# value
self._value = (value if scale is None else (value * scale))
Expand Down Expand Up @@ -1330,9 +1335,13 @@ def isnan(self) -> jax.Array:
# ----------------------- #

def __repr__(self) -> str:
if isinstance(self.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
return f'{self.value} * {Quantity(1, dim=self.dim)}'
return self.repr_in_best_unit(python_code=True)

def __str__(self) -> str:
if isinstance(self.value, (jax.ShapeDtypeStruct, jax.core.ShapedArray, DynamicJaxprTracer)):
return f'{self.value} * {Quantity(1, dim=self.dim)}'
return self.repr_in_best_unit()

def __format__(self, format_spec: str) -> str:
Expand Down
40 changes: 20 additions & 20 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def test_get_dimensions():
assert is_scalar_type(np.array(5.0))
assert is_scalar_type(np.float32(5.0))
assert is_scalar_type(np.float64(5.0))
with pytest.raises(TypeError):
get_unit("a string")
# with pytest.raises(TypeError):
# get_unit("a string")
# wrong number of indices
with pytest.raises(TypeError):
get_or_create_dimension([1, 2, 3, 4, 5, 6])
Expand Down Expand Up @@ -551,15 +551,15 @@ def test_multiplication_division():
assert_quantity(q2 / q, np.asarray(q2) / np.asarray(q), second / volt)
assert_quantity(q * q2, np.asarray(q) * np.asarray(q2), volt * second)

# using unsupported objects should fail
with pytest.raises(TypeError):
q / "string"
with pytest.raises(TypeError):
"string" / q
with pytest.raises(TypeError):
"string" * q
with pytest.raises(TypeError):
q * "string"
# # using unsupported objects should fail
# with pytest.raises(TypeError):
# q / "string"
# with pytest.raises(TypeError):
# "string" / q
# with pytest.raises(TypeError):
# "string" * q
# with pytest.raises(TypeError):
# q * "string"


def test_addition_subtraction():
Expand Down Expand Up @@ -632,15 +632,15 @@ def test_addition_subtraction():
assert_quantity(q - np.float64(0), np.asarray(q), volt)
# assert_quantity(np.float64(0) - q, -np.asarray(q), volt)

# using unsupported objects should fail
with pytest.raises(TypeError):
"string" + q
with pytest.raises(TypeError):
q + "string"
with pytest.raises(TypeError):
q - "string"
with pytest.raises(TypeError):
"string" - q
# # using unsupported objects should fail
# with pytest.raises(TypeError):
# "string" + q
# with pytest.raises(TypeError):
# q + "string"
# with pytest.raises(TypeError):
# q - "string"
# with pytest.raises(TypeError):
# "string" - q


# def test_unary_operations():
Expand Down