Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
# functions for checking
'check_dims',
'check_units',
'assign_units',
'fail_for_dimension_mismatch',
'fail_for_unit_mismatch',
'assert_quantity',
Expand Down Expand Up @@ -4453,6 +4454,78 @@ def new_f(*args, **kwds):
return do_check_units


@set_module_as('brainunit')
def assign_units(**au):
"""
Decorator to transform units of arguments passed to a function
"""

def do_assign_units(f):
@wraps(f)
def new_f(*args, **kwds):
newkeyset = kwds.copy()
arg_names = f.__code__.co_varnames[0: f.__code__.co_argcount]
for n, v in zip(arg_names, args[0: f.__code__.co_argcount]):
if n in au and v is not None:
specific_unit = au[n]
# if the specific unit is a boolean, just check and return
if specific_unit == bool:
if isinstance(v, bool):
newkeyset[n] = v
else:
raise TypeError(f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'")

elif specific_unit == 1:
if isinstance(v, Quantity):
newkeyset[n] = v.to_decimal()
elif isinstance(v, (jax.Array, np.ndarray, int, float, complex)):
newkeyset[n] = v
else:
specific_unit = jax.typing.ArrayLike
raise TypeError(f"Function '{f.__name__}' expected a unitless Quantity object"
f"or {specific_unit} for argument '{n}' but got '{v}'")

elif isinstance(specific_unit, Unit):
if isinstance(v, Quantity):
v = v.to_decimal(specific_unit)
newkeyset[n] = v
else:
raise TypeError(
f"Function '{f.__name__}' expected a Quantity object for argument '{n}' but got '{v}'"
)
else:
raise TypeError(
f"Function '{f.__name__}' expected a target unit object or"
f" a Number, boolean object for checking, but got '{specific_unit}'"
)
else:
newkeyset[n] = v

result = f(**newkeyset)
if "result" in au:
if isinstance(au["result"], Callable) and au["result"] != bool:
expected_result = au["result"](*[get_unit(a) for a in args])
else:
expected_result = au["result"]

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)

result = jax.tree.map(
partial(_assign_unit, f), result, expected_result
)
return result

return new_f

return do_assign_units

def _check_unit(f, val, unit):
unit = UNITLESS if unit is None else unit
if not has_same_unit(val, unit):
Expand All @@ -4464,6 +4537,11 @@ def _check_unit(f, val, unit):
)
raise UnitMismatchError(error_message, get_unit(val))

def _assign_unit(f, val, unit):
if unit is None or unit == bool or unit == 1:
return val
return Quantity(val, unit=unit)


def _is_quantity(x):
return isinstance(x, Quantity)
81 changes: 81 additions & 0 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,87 @@ def d_function2(true_result):
with pytest.raises(u.UnitMismatchError):
d_function2(2)

def test_assign_units(self):
"""
Test the assign_units decorator
"""

@u.assign_units(v=volt)
def a_function(v, x):
"""
v has to have units of volt, x can have any (or no) unit.
"""
return v

# Try correct units
assert a_function(3 * mV, 5 * second) == (3 * mV).to_decimal(volt)
assert a_function(3 * volt, 5 * second) == (3 * volt).to_decimal(volt)
assert a_function(5 * volt, "something") == (5 * volt).to_decimal(volt)
assert_quantity(a_function([1, 2, 3] * volt, None), ([1, 2, 3] * volt).to_decimal(volt))

# Try incorrect units
with pytest.raises(u.UnitMismatchError):
a_function(5 * second, None)
with pytest.raises(TypeError):
a_function(5, None)
with pytest.raises(TypeError):
a_function(object(), None)

@u.assign_units(result=second)
def b_function():
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
return 5

# Should work (returns second)
assert b_function() == 5 * second

@u.assign_units(a=bool, b=1, result=bool)
def c_function(a, b):
if a:
return b > 0
else:
return b

assert c_function(True, 1)
assert not c_function(True, -1)
with pytest.raises(TypeError):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# Multiple results
@u.assign_units(result=(second, volt))
def d_function():
return 5, 3

# Should work (returns second)
assert d_function()[0] == 5 * second
assert d_function()[1] == 3 * volt

# Multiple results
@u.assign_units(result={'u': second, 'v': (volt, metre)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5, 'v': (3, 2)}
elif true_result == 1:
return 3, 5
else:
return 3, 5

# Should work (returns dict)
d_function2(0)
# Should fail (returns tuple)
with pytest.raises(TypeError):
d_function2(1)



def test_str_repr():
"""
Expand Down
Loading