Skip to content

Commit

Permalink
Merge pull request #63 from iguinn/main
Browse files Browse the repository at this point in the history
Improved handling of coordinate variables
  • Loading branch information
iguinn authored Mar 24, 2024
2 parents 911f287 + 74dfe76 commit 53e478f
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 46 deletions.
120 changes: 77 additions & 43 deletions src/dspeed/processing_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
}


# helper function to tell if an object is found in the unit registry
def is_in_pint(unit):
return isinstance(unit, (Unit, Quantity)) or (unit and unit in ureg)


@dataclass
class CoordinateGrid:
"""Helper class that describes a system of units, consisting of a period
Expand Down Expand Up @@ -257,7 +262,7 @@ def get_buffer(self, unit: str | Unit = None) -> np.ndarray:
# if no unit is given, use the native unit/coordinate grid
if unit is None:
unit = self.grid if self.is_coord else self.unit
if not isinstance(unit, CoordinateGrid) and unit and unit in ureg:
if not isinstance(unit, CoordinateGrid) and is_in_pint(unit):
unit = CoordinateGrid(unit)

if isinstance(self._buffer, np.ndarray):
Expand All @@ -267,18 +272,14 @@ def get_buffer(self, unit: str | Unit = None) -> np.ndarray:
elif unit is not None:
self.grid = CoordinateGrid(unit)

if unit is None or not (
isinstance(unit, (Unit, Quantity, CoordinateGrid)) or unit in ureg
):
if not (isinstance(unit, CoordinateGrid) or is_in_pint(unit)):
# buffer cannot be converted so return
return self._buffer
else:
# buffer can be converted, so make it a list of buffers
self._buffer = [(self._buffer, unit)]

if unit is None or not (
isinstance(unit, (Unit, Quantity, CoordinateGrid)) or unit in ureg
):
if not isinstance(unit, CoordinateGrid) and not is_in_pint(unit):
return self._buffer[0][0]

# check if coordinate conversion has been done already
Expand Down Expand Up @@ -488,20 +489,16 @@ def set_constant(
"""

param = self.get_variable(varname)
assert param.is_constant or param._buffer is None
param.is_constant = True
assert param.is_const or param._buffer is None
param.is_const = True

if isinstance(val, Quantity):
unit = val.unit
val = val.magnitude

val = np.array(val, dtype=dtype)

param.update_auto(
shape=val.shape,
dtype=val.dtype,
unit=unit,
)
param.update_auto(shape=val.shape, dtype=val.dtype, unit=unit, is_coord=False)
np.copyto(param.get_buffer(), val, casting="unsafe")
log.debug(f"set constant: {param.description()} = {val}")
return param
Expand Down Expand Up @@ -828,8 +825,27 @@ def _parse_expr(

name = "(" + op_form.format(str(lhs), str(rhs)) + ")"
if isinstance(lhs, ProcChainVar) and isinstance(rhs, ProcChainVar):
# TODO: handle units/coords; for now make them match lhs
out = ProcChainVar(self, name, is_coord=lhs.is_coord)
if is_in_pint(lhs.unit) and is_in_pint(rhs.unit):
unit = op(Quantity(lhs.unit), Quantity(rhs.unit)).u
if unit == ureg.dimensionless:
unit = None
elif lhs.unit is not None and rhs.unit is not None:
unit = op_form.format(str(lhs.unit), str(rhs.unit))
elif lhs.unit is not None:
unit = lhs.unit
else:
unit = rhs.unit
# If both vars are coordinates, this is probably not a coord.
# If one var is a coord, this is probably a coord
out = ProcChainVar(
self,
name,
grid=None if lhs.is_coord and rhs.is_coord else auto,
is_coord=False
if lhs.is_coord is True and rhs.is_coord is True
else auto,
unit=unit,
)
elif isinstance(lhs, ProcChainVar):
out = ProcChainVar(
self,
Expand Down Expand Up @@ -880,7 +896,7 @@ def _parse_expr(
val = self._parse_expr(node.value, expr, dry_run, var_name_list)
if val is None:
return None
if not isinstance(val, ProcChainVar):
if not isinstance(val, ProcChainVar) or not len(val.shape) > 0:
raise ProcessingChainError("Cannot apply subscript to", node.value)

def get_index(slice_value):
Expand All @@ -898,11 +914,11 @@ def get_index(slice_value):
return round_ret
return int(ret)

if isinstance(node.slice, ast.Index):
index = get_index(node.slice.value)
out_buf = val[..., index]
out_name = (f"{str(val)}[{index}]",)
out_grid = None
if isinstance(node.slice, ast.Constant):
index = get_index(node.slice)
out_buf = val.buffer[..., index]
out_name = f"{str(val)}[{index}]"
out_grid = val.grid if val.is_coord else None

elif isinstance(node.slice, ast.Slice):
sl = slice(
Expand Down Expand Up @@ -1283,11 +1299,14 @@ def __init__(
d.strip() for d in dims.split(",") if d
]
arr_dims = list(param.shape)
arr_grid = (
param.grid
if isinstance(param, ProcChainVar) and param.grid is not auto
else None
)
if (
isinstance(param, ProcChainVar)
and param.grid is not auto
and not param.is_coord
):
arr_grid = param.grid
else:
arr_grid = None
if not grid:
grid = arr_grid

Expand Down Expand Up @@ -1411,8 +1430,7 @@ def __init__(
unit = str(grid.period.u)
this_grid = grid
elif (
isinstance(param.unit, str)
and param.unit in ureg
is_in_pint(param.unit)
and grid is not None
and ureg.is_compatible_with(grid.period, param.unit)
):
Expand Down Expand Up @@ -1721,19 +1739,28 @@ def __init__(self, io_array: np.ArrayOfEqualSizedArrays, var: ProcChainVar) -> N
unit = io_array.attrs.get("units", None)
var.update_auto(dtype=io_array.dtype, shape=io_array.nda.shape[1:], unit=unit)

if isinstance(var.unit, CoordinateGrid):
if isinstance(var.unit, (CoordinateGrid, Quantity, Unit)):
if isinstance(var.unit, CoordinateGrid):
var_u = var.unit.period.u
elif isinstance(var.unit, Quantity):
var_u = var.unit.u
else:
var_u = var.unit

if unit is None:
unit = var.unit.period.u
elif ureg.is_compatible_with(var.unit.period, unit):
unit = var_u
elif ureg.is_compatible_with(var_u, unit):
unit = ureg.Quantity(unit).u
else:
raise ProcessingChainError(
f"LGDO array and variable {var} have incompatible units "
f"({var.unit.period.u} and {unit})"
f"({var_u} and {unit})"
)
elif isinstance(var.unit, str) and unit is None:
unit = var.unit

if unit is None and var.unit is not None:
io_array.attrs["units"] = str(var.unit)
if "units" not in io_array.attrs and unit is not None:
io_array.attrs["units"] = str(unit)

self.io_array = io_array
self.raw_buf = io_array.nda
Expand Down Expand Up @@ -1781,22 +1808,29 @@ def __init__(self, io_vov: lgdo.VectorOfVectors, var: ProcChainVar) -> None:

unit = io_vov.attrs.get("units", None)
var.update_auto(dtype=io_vov.dtype, shape=10, unit=unit)
if var.vector_len is None:
var.vector_len = (f"{var.name}_len",)

if isinstance(var.unit, CoordinateGrid):
if isinstance(var.unit, (CoordinateGrid, Quantity, Unit)):
if isinstance(var.unit, CoordinateGrid):
var_u = var.unit.period.u
elif isinstance(var.unit, Quantity):
var_u = var.unit.u
else:
var_u = var.unit

if unit is None:
unit = var.unit.period.u
elif ureg.is_compatible_with(var.unit.period, unit):
unit = var_u
elif ureg.is_compatible_with(var_u, unit):
unit = ureg.Quantity(unit).u
else:
raise ProcessingChainError(
f"LGDO array and variable {var} have incompatible units "
f"({var.unit.period.u} and {unit})"
f"({var_u} and {unit})"
)
elif isinstance(var.unit, str) and unit is None:
unit = var.unit

if unit is None and var.unit is not None:
io_vov.attrs["units"] = str(var.unit)
if "units" not in io_vov.attrs and unit is not None:
io_vov.attrs["units"] = str(unit)

self.io_vov = io_vov
self.raw_buf = io_vov.flattened_data
Expand Down
8 changes: 5 additions & 3 deletions src/dspeed/vis/waveform_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,12 @@ def find_entry(
elif isinstance(
data, (lgdo.Array, lgdo.ArrayOfEqualSizedArrays, lgdo.VectorOfVectors)
):
if isinstance(data, lgdo.Array):
vals = [data.nda[i_tb]]
if isinstance(
data, (lgdo.ArrayOfEqualSizedArrays, lgdo.VectorOfVectors)
):
vals = list(data.nda[i_tb])
else:
vals = data[i_tb]
vals = [data.nda[i_tb]]

unit = data.attrs.get("units", None)
if unit and unit in ureg and ureg.is_compatible_with(unit, self.x_unit):
Expand Down

0 comments on commit 53e478f

Please sign in to comment.