Skip to content

Fixes for 2021 tutorials. #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 177 additions & 44 deletions src/vector/_backends/awkward_.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,26 +1012,49 @@ class MomentumRecord4D(MomentumAwkward4D, ak.Record):
# implementation of behaviors in Numba ########################################


def _aztype_of(recordarraytype: typing.Any) -> typing.Any:
def _aztype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
import numba

cls: typing.Union[
typing.Type[AzimuthalObjectXY],
typing.Type[AzimuthalObjectRhoPhi],
]

try:
x_index = recordarraytype.recordlookup.index("x")
except ValueError:
x_index = None
try:
y_index = recordarraytype.recordlookup.index("y")
except ValueError:
y_index = None
try:
rho_index = recordarraytype.recordlookup.index("rho")
except ValueError:
rho_index = None
x_index = None
y_index = None
rho_index = None
phi_index = None

if is_momentum:
try:
x_index = recordarraytype.recordlookup.index("px")
except ValueError:
x_index = None
if x_index is None:
try:
x_index = recordarraytype.recordlookup.index("x")
except ValueError:
x_index = None
if is_momentum:
try:
y_index = recordarraytype.recordlookup.index("py")
except ValueError:
y_index = None
if y_index is None:
try:
y_index = recordarraytype.recordlookup.index("y")
except ValueError:
y_index = None
if is_momentum:
try:
rho_index = recordarraytype.recordlookup.index("pt")
except ValueError:
rho_index = None
if rho_index is None:
try:
rho_index = recordarraytype.recordlookup.index("rho")
except ValueError:
rho_index = None
try:
phi_index = recordarraytype.recordlookup.index("phi")
except ValueError:
Expand All @@ -1047,6 +1070,11 @@ def _aztype_of(recordarraytype: typing.Any) -> typing.Any:
coord2 = recordarraytype.contenttypes[phi_index].arraytype.dtype
cls = AzimuthalObjectRhoPhi

elif is_momentum:
raise numba.TypingError(
f"{recordarraytype} is missing azimuthal fields: px/py (x/y) or pt/phi (rho/phi)"
)

else:
raise numba.TypingError(
f"{recordarraytype} is missing azimuthal fields: x/y or rho/phi"
Expand All @@ -1055,7 +1083,7 @@ def _aztype_of(recordarraytype: typing.Any) -> typing.Any:
return numba.typeof(cls(coord1.cast_python_value(0), coord2.cast_python_value(0)))


def _ltype_of(recordarraytype: typing.Any) -> typing.Any:
def _ltype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
import numba

cls: typing.Union[
Expand All @@ -1064,10 +1092,20 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any:
typing.Type[LongitudinalObjectEta],
]

try:
z_index = recordarraytype.recordlookup.index("z")
except ValueError:
z_index = None
z_index = None
theta_index = None
eta_index = None

if is_momentum:
try:
z_index = recordarraytype.recordlookup.index("pz")
except ValueError:
z_index = None
if z_index is None:
try:
z_index = recordarraytype.recordlookup.index("z")
except ValueError:
z_index = None
try:
theta_index = recordarraytype.recordlookup.index("theta")
except ValueError:
Expand All @@ -1089,6 +1127,11 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any:
coord1 = recordarraytype.contenttypes[eta_index].arraytype.dtype
cls = LongitudinalObjectEta

elif is_momentum:
raise numba.TypingError(
f"{recordarraytype} is missing longitudinal fields: pz (z) or theta or eta"
)

else:
raise numba.TypingError(
f"{recordarraytype} is missing longitudinal fields: z or theta or eta"
Expand All @@ -1097,22 +1140,57 @@ def _ltype_of(recordarraytype: typing.Any) -> typing.Any:
return numba.typeof(cls(coord1.cast_python_value(0)))


def _ttype_of(recordarraytype: typing.Any) -> typing.Any:
def _ttype_of(recordarraytype: typing.Any, is_momentum: bool) -> typing.Any:
import numba

cls: typing.Union[
typing.Type[TemporalObjectT],
typing.Type[TemporalObjectTau],
]

try:
t_index = recordarraytype.recordlookup.index("t")
except ValueError:
t_index = None
try:
tau_index = recordarraytype.recordlookup.index("tau")
except ValueError:
tau_index = None
t_index = None
tau_index = None

if is_momentum:
try:
t_index = recordarraytype.recordlookup.index("E")
except ValueError:
t_index = None
if is_momentum and t_index is None:
try:
t_index = recordarraytype.recordlookup.index("e")
except ValueError:
t_index = None
if is_momentum and t_index is None:
try:
t_index = recordarraytype.recordlookup.index("energy")
except ValueError:
t_index = None
if t_index is None:
try:
t_index = recordarraytype.recordlookup.index("t")
except ValueError:
t_index = None
if is_momentum:
try:
tau_index = recordarraytype.recordlookup.index("M")
except ValueError:
tau_index = None
if is_momentum and tau_index is None:
try:
tau_index = recordarraytype.recordlookup.index("m")
except ValueError:
tau_index = None
if is_momentum and tau_index is None:
try:
tau_index = recordarraytype.recordlookup.index("mass")
except ValueError:
tau_index = None
if tau_index is None:
try:
tau_index = recordarraytype.recordlookup.index("tau")
except ValueError:
tau_index = None

if t_index is not None:
coord1 = recordarraytype.contenttypes[t_index].arraytype.dtype
Expand All @@ -1122,6 +1200,11 @@ def _ttype_of(recordarraytype: typing.Any) -> typing.Any:
coord1 = recordarraytype.contenttypes[tau_index].arraytype.dtype
cls = TemporalObjectTau

elif is_momentum:
raise numba.TypingError(
f"{recordarraytype} is missing temporal fields: E/e/energy (t) or M/m/mass (tau)"
)

else:
raise numba.TypingError(
f"{recordarraytype} is missing temporal fields: t or tau"
Expand All @@ -1135,80 +1218,112 @@ def _numba_typer_Vector2D(viewtype: typing.Any) -> typing.Any:

# These clearly exist, a bug somewhere, but ignoring them for now
return vector._backends.numba_object.VectorObject2DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type)
_aztype_of(viewtype.arrayviewtype.type, False)
)


def _numba_typer_Vector3D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.VectorObject3DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type),
_ltype_of(viewtype.arrayviewtype.type),
_aztype_of(viewtype.arrayviewtype.type, False),
_ltype_of(viewtype.arrayviewtype.type, False),
)


def _numba_typer_Vector4D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.VectorObject4DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type),
_ltype_of(viewtype.arrayviewtype.type),
_ttype_of(viewtype.arrayviewtype.type),
_aztype_of(viewtype.arrayviewtype.type, False),
_ltype_of(viewtype.arrayviewtype.type, False),
_ttype_of(viewtype.arrayviewtype.type, False),
)


def _numba_typer_Momentum2D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.MomentumObject2DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type)
_aztype_of(viewtype.arrayviewtype.type, True)
)


def _numba_typer_Momentum3D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.MomentumObject3DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type),
_ltype_of(viewtype.arrayviewtype.type),
_aztype_of(viewtype.arrayviewtype.type, True),
_ltype_of(viewtype.arrayviewtype.type, True),
)


def _numba_typer_Momentum4D(viewtype: typing.Any) -> typing.Any:
import vector._backends.numba_object

return vector._backends.numba_object.MomentumObject4DType( # type: ignore
_aztype_of(viewtype.arrayviewtype.type),
_ltype_of(viewtype.arrayviewtype.type),
_ttype_of(viewtype.arrayviewtype.type),
_aztype_of(viewtype.arrayviewtype.type, True),
_ltype_of(viewtype.arrayviewtype.type, True),
_ttype_of(viewtype.arrayviewtype.type, True),
)


def _numba_lower(
context: typing.Any, builder: typing.Any, sig: typing.Any, args: typing.Any
) -> typing.Any:
from vector._backends.numba_object import ( # type: ignore
_awkward_numba_E,
_awkward_numba_e,
_awkward_numba_energy,
_awkward_numba_eta,
_awkward_numba_M,
_awkward_numba_m,
_awkward_numba_mass,
_awkward_numba_ptphi,
_awkward_numba_pxpy,
_awkward_numba_pxy,
_awkward_numba_pz,
_awkward_numba_rhophi,
_awkward_numba_t,
_awkward_numba_tau,
_awkward_numba_theta,
_awkward_numba_xpy,
_awkward_numba_xy,
_awkward_numba_z,
)

vectorcls = sig.return_type.instance_class

fields = sig.args[0].arrayviewtype.type.recordlookup

if issubclass(vectorcls, (VectorObject2D, VectorObject3D, VectorObject4D)):
if issubclass(sig.return_type.azimuthaltype.instance_class, AzimuthalXY):
azimuthal = _awkward_numba_xy
if "x" in fields and "y" in fields:
azimuthal = _awkward_numba_xy
elif "x" in fields and "py" in fields:
azimuthal = _awkward_numba_xpy
elif "px" in fields and "y" in fields:
azimuthal = _awkward_numba_pxy
elif "px" in fields and "py" in fields:
azimuthal = _awkward_numba_pxpy
else:
raise AssertionError
elif issubclass(sig.return_type.azimuthaltype.instance_class, AzimuthalRhoPhi):
azimuthal = _awkward_numba_rhophi
if "rho" in fields and "phi" in fields:
azimuthal = _awkward_numba_rhophi
elif "pt" in fields and "phi" in fields:
azimuthal = _awkward_numba_ptphi
else:
raise AssertionError

if issubclass(vectorcls, (VectorObject3D, VectorObject4D)):
if issubclass(sig.return_type.longitudinaltype.instance_class, LongitudinalZ):
longitudinal = _awkward_numba_z
if "z" in fields:
longitudinal = _awkward_numba_z
elif "pz" in fields:
longitudinal = _awkward_numba_pz
else:
raise AssertionError
elif issubclass(
sig.return_type.longitudinaltype.instance_class, LongitudinalTheta
):
Expand All @@ -1220,9 +1335,27 @@ def _numba_lower(

if issubclass(vectorcls, VectorObject4D):
if issubclass(sig.return_type.temporaltype.instance_class, TemporalT):
temporal = _awkward_numba_t
if "t" in fields:
temporal = _awkward_numba_t
elif "E" in fields:
temporal = _awkward_numba_E
elif "e" in fields:
temporal = _awkward_numba_e
elif "energy" in fields:
temporal = _awkward_numba_energy
else:
raise AssertionError
elif issubclass(sig.return_type.temporaltype.instance_class, TemporalTau):
temporal = _awkward_numba_tau
if "tau" in fields:
temporal = _awkward_numba_tau
elif "M" in fields:
temporal = _awkward_numba_M
elif "m" in fields:
temporal = _awkward_numba_m
elif "mass" in fields:
temporal = _awkward_numba_mass
else:
raise AssertionError

if issubclass(vectorcls, VectorObject2D):

Expand Down
Loading