Add test for _get_scaling #5515

Closed
@ricardoV94

This function has a somewhat long and complicated sequence of conditional branches. AFAICT there is no direct test, so we should add one that covers all the expected cases to facilitate future maintenance and refactoring.

Bonus points for adding type hints

def _get_scaling(total_size, shape, ndim):
    """
    Gets scaling constant for logp
    Parameters
    ----------
    total_size: int or list[int]
    shape: shape
        shape to scale
    ndim: int
        ndim hint
    Returns
    -------
    scalar
    """
    if total_size is None:
        coef = floatX(1)
    elif isinstance(total_size, int):
        if ndim >= 1:
            denom = shape[0]
        else:
            denom = 1
        coef = floatX(total_size) / floatX(denom)
    elif isinstance(total_size, (list, tuple)):
        if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
            raise TypeError(
                "Unrecognized `total_size` type, expected "
                "int or list of ints, got %r" % total_size
            )
        if Ellipsis in total_size:
            sep = total_size.index(Ellipsis)
            begin = total_size[:sep]
            end = total_size[sep + 1 :]
            if Ellipsis in end:
                raise ValueError(
                    "Double Ellipsis in `total_size` is restricted, got %r" % total_size
                )
        else:
            begin = total_size
            end = []
        if (len(begin) + len(end)) > ndim:
            raise ValueError(
                "Length of `total_size` is too big, "
                "number of scalings is bigger that ndim, got %r" % total_size
            )
        elif (len(begin) + len(end)) == 0:
            return floatX(1)
        if len(end) > 0:
            shp_end = shape[-len(end) :]
        else:
            shp_end = np.asarray([])
        shp_begin = shape[: len(begin)]
        begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
        end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
        coefs = begin_coef + end_coef
        coef = at.prod(coefs)
    else:
        raise TypeError(
            "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
        )
    return at.as_tensor(floatX(coef))
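
To make the branches above concrete, here is a rough sketch of a case table such a test could cover. It is not the actual test: `reference_scaling`, `TotalSize`, `VALID_CASES`, `ERROR_CASES` and `check_cases` are hypothetical names, and the reference function is a pure-Python stand-in that uses plain floats in place of `floatX` and the tensor operations, so it only mirrors the branching logic of the snippet. The type hints are likewise just one possible reading of the docstring.

```python
from math import prod
from typing import Callable, Sequence, Union

# Hypothetical alias for the accepted `total_size` values (not from the source).
TotalSize = Union[None, int, Sequence[Union[int, None, type(Ellipsis)]]]


def reference_scaling(total_size: TotalSize, shape: Sequence[int], ndim: int) -> float:
    """Pure-Python stand-in that mirrors the branches of `_get_scaling`."""
    if total_size is None:
        return 1.0
    if isinstance(total_size, int):
        denom = shape[0] if ndim >= 1 else 1
        return total_size / denom
    if isinstance(total_size, (list, tuple)):
        if not all(isinstance(i, int) for i in total_size if i is not Ellipsis and i is not None):
            raise TypeError(f"expected int or list of ints, got {total_size!r}")
        if Ellipsis in total_size:
            sep = total_size.index(Ellipsis)
            begin, end = total_size[:sep], total_size[sep + 1 :]
            if Ellipsis in end:
                raise ValueError(f"double Ellipsis in total_size, got {total_size!r}")
        else:
            begin, end = total_size, []
        if len(begin) + len(end) > ndim:
            raise ValueError(f"more scalings than ndim, got {total_size!r}")
        if len(begin) + len(end) == 0:
            return 1.0
        shp_begin = shape[: len(begin)]
        shp_end = shape[-len(end) :] if len(end) > 0 else ()
        coefs = [t / shp_begin[i] for i, t in enumerate(begin) if t is not None]
        coefs += [t / shp_end[i] for i, t in enumerate(end) if t is not None]
        return float(prod(coefs))
    raise TypeError(f"expected int or list of ints, got {total_size!r}")


# (total_size, shape, ndim, expected scaling), one row per branch.
VALID_CASES = [
    (None, (10, 5), 2, 1.0),                      # no scaling requested
    (100, (10, 5), 2, 10.0),                      # int, scales the first axis
    (100, (), 0, 100.0),                          # int with ndim == 0, denom is 1
    ([100], (10, 5), 2, 10.0),                    # list shorter than ndim
    ([100, 20], (10, 5), 2, 40.0),                # one entry per axis
    ([None, 20], (10, 5), 2, 4.0),                # None skips an axis
    ([None], (10,), 1, 1.0),                      # only None: empty product
    ([Ellipsis, 20], (10, 5), 2, 4.0),            # Ellipsis, trailing axis only
    ([100, Ellipsis, 20], (10, 3, 5), 3, 40.0),   # leading and trailing axes
    ([Ellipsis], (10, 5), 2, 1.0),                # nothing before or after Ellipsis
    ([], (10, 5), 2, 1.0),                        # empty list
]

# (total_size, shape, ndim, expected exception type).
ERROR_CASES = [
    ([Ellipsis, 1, Ellipsis], (10, 5), 2, ValueError),  # double Ellipsis
    ([1, 2, 3], (10, 5), 2, ValueError),                # more scalings than ndim
    ([1.5], (10,), 1, TypeError),                       # non-int list entry
    (1.5, (10,), 1, TypeError),                         # unsupported scalar type
    ("10", (10,), 1, TypeError),                        # unsupported type
]


def check_cases(fn: Callable[..., float] = reference_scaling) -> int:
    """Run every case against `fn` and return the number of cases checked."""
    for total_size, shape, ndim, expected in VALID_CASES:
        got = fn(total_size, shape, ndim)
        assert abs(got - expected) < 1e-6, (total_size, shape, ndim, got)
    for total_size, shape, ndim, exc in ERROR_CASES:
        try:
            fn(total_size, shape, ndim)
        except exc:
            continue
        raise AssertionError(f"expected {exc.__name__} for total_size={total_size!r}")
    return len(VALID_CASES) + len(ERROR_CASES)
```

In a real test these rows would more likely be fed to `pytest.mark.parametrize` and run against `_get_scaling` itself, where the returned tensor would presumably need to be evaluated before comparison. One detail worth covering explicitly: the `len(begin) + len(end) == 0` branch returns `floatX(1)` directly instead of going through `at.as_tensor`, so the return type differs from the other branches.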

Related to #4582
