-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Closed
Labels
Description
This function has a fairly long and complicated sequence of conditional branches. As far as I can tell there is no direct test for it, so we should add one that covers all the expected cases to make future maintenance and refactoring easier.
Bonus points for adding type hints as well.
pymc/pymc/distributions/logprob.py
Lines 46 to 107 in 114e439
def _get_scaling(total_size: "int | list | tuple | None", shape, ndim: int):
    """
    Gets scaling constant for logp

    The constant is the product of ``total_size[i] / shape[i]`` over every
    dimension that has a total size. (Presumably this rescales a logp computed
    on a subset of the data up to the full data size -- confirm against callers.)

    Parameters
    ----------
    total_size: None, int or list/tuple of int, None and at most one Ellipsis
        ``None`` disables scaling (coefficient 1). An int is the total size of
        the first dimension, so the coefficient is ``total_size / shape[0]``
        (or ``total_size`` itself when ``ndim == 0``). In a list/tuple, entries
        before the ``Ellipsis`` (or all entries, if there is none) pair with the
        leading dimensions of ``shape``, and entries after it pair with the
        trailing dimensions; ``None`` entries leave their dimension unscaled.
    shape: shape
        shape to scale; only the dimensions referenced by ``total_size`` are read
    ndim: int
        ndim hint; upper bound on the number of scalings in ``total_size``

    Returns
    -------
    scalar
        ``at.as_tensor`` of the coefficient cast with ``floatX``. This is a
        tensor on every path, including when no dimension is scaled.

    Raises
    ------
    TypeError
        If ``total_size`` is not None, an int, a list or a tuple, or if a
        list/tuple entry is not an int, None or Ellipsis.
    ValueError
        If ``total_size`` contains more than one Ellipsis, or has more
        non-Ellipsis entries than ``ndim``.
    """
    if total_size is None:
        coef = floatX(1)
    elif isinstance(total_size, int):
        # A scalar (ndim == 0) has no leading dimension to divide by.
        denom = shape[0] if ndim >= 1 else 1
        coef = floatX(total_size) / floatX(denom)
    elif isinstance(total_size, (list, tuple)):
        if not all(isinstance(i, int) for i in total_size if i is not Ellipsis and i is not None):
            # f-strings (not `%`) so a tuple `total_size` is rendered as a whole
            # instead of being unpacked as formatting arguments.
            raise TypeError(
                f"Unrecognized `total_size` type, expected int or list of ints, got {total_size!r}"
            )
        if Ellipsis in total_size:
            sep = total_size.index(Ellipsis)
            begin = total_size[:sep]
            end = total_size[sep + 1 :]
            if Ellipsis in end:
                raise ValueError(
                    f"Double Ellipsis in `total_size` is restricted, got {total_size!r}"
                )
        else:
            begin = total_size
            end = []
        if len(begin) + len(end) > ndim:
            raise ValueError(
                "Length of `total_size` is too big, "
                f"number of scalings is bigger than ndim, got {total_size!r}"
            )
        coefs = []
        if begin:
            # Leading entries pair with the first dimensions of `shape`.
            shp_begin = shape[: len(begin)]
            coefs += [
                floatX(t) / floatX(shp_begin[i]) for i, t in enumerate(begin) if t is not None
            ]
        if end:
            # Trailing entries pair with the last dimensions of `shape`.
            shp_end = shape[-len(end) :]
            coefs += [
                floatX(t) / floatX(shp_end[i]) for i, t in enumerate(end) if t is not None
            ]
        # No effective scaling (empty `total_size`, a lone Ellipsis, or only None
        # entries) gives 1, still returned as a tensor by the common return below.
        coef = at.prod(coefs) if coefs else floatX(1)
    else:
        raise TypeError(
            f"Unrecognized `total_size` type, expected int or list of ints, got {total_size!r}"
        )
    return at.as_tensor(floatX(coef))
Related to #4582