Skip to content

Fix CubicInterpolation for the case where ys = {}. Add unit tests for… #360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

allen-adastra
Copy link
Contributor

… both cubic and linear interpolation checking this case.

… both cubic and linear interpolation checking this case
@allen-adastra allen-adastra force-pushed the empty_trees_for_interp branch from 8bc215c to db7e8d4 Compare January 25, 2024 16:26
@patrick-kidger
Copy link
Owner

Thank you for the fix! What error did you get with the existing code in this scenario, by the way?

If I understand correctly this will still do the wrong thing for e.g. ys = [] though -- i.e. other zero-leaf PyTrees, though?

…d unit test case to test_interpolation_classes
@allen-adastra
Copy link
Contributor Author

Good point; made it work with [] too.

Error without the fix:

ts = Traced<ShapedArray(float64[8])>with<DynamicJaxprTrace(level=1/0)>, ys = {}

    @eqx.filter_jit
    def backward_hermite_coefficients(
        ts: Real[Array, " times"],
        ys: PyTree[Shaped[Array, "times ?*shape"], "Y"],
        *,
        deriv0: Optional[PyTree[Shaped[Array, "?#*shape"], "Y"]] = None,
        replace_nans_at_start: Optional[PyTree[Shaped[ArrayLike, "?#*shape"], "Y"]] = None,
        fill_forward_nans_at_end: bool = False,
    ) -> tuple[
        PyTree[Shaped[Array, "times-1 ?*shape"], "Y"],
        PyTree[Shaped[Array, "times-1 ?*shape"], "Y"],
        PyTree[Shaped[Array, "times-1 ?*shape"], "Y"],
        PyTree[Shaped[Array, "times-1 ?*shape"], "Y"],
    ]:
        """Interpolates the data with a cubic spline. Specifically, this calculates the
        coefficients for Hermite cubic splines with backward differences.
    
        This is most useful prior to using [`diffrax.CubicInterpolation`][] to create a
        smooth path from discrete observations.
    
        ??? cite "Reference"
    
            Hermite cubic splines with backward differences were introduced in this paper:
    
            ```bibtex
            @article{morrill2021cdeonline,
                    title={{N}eural {C}ontrolled {D}ifferential {E}quations for {O}nline
                           {P}rediction {T}asks},
                    author={Morrill, James and Kidger, Patrick and Yang, Lingyi and
                            Lyons, Terry},
                    journal={arXiv:2106.11028},
                    year={2021}
            }
            ```
    
        **Arguments:**
    
        - `ts`: The time of each observation.
        - `ys`: The observations themselves. Should use `NaN` to indicate missing data.
        - `deriv0`: The derivative at `ts[0]`. If not passed then a forward difference of
            `(ys[i] - ys[0]) / (ts[i] - ts[0])` is used, where `i` is the index of the
            first non-`NaN` element of `ys`.
        - `fill_forward_nans_at_end`: By default `NaN` values at the end (with no non-`NaN`
            value after them) are left as `NaN`s. If this is set then they will instead
            be filled in using the last non-`NaN` value prior to fitting the cubic spline.
        - `replace_nans_at_start`: By default `NaN` values at the start (with no non-`NaN`
            value before them) are left as `NaN`s. If this is passed then it will be used
            to fill in such `NaN` values.
    
        **Returns:**
    
        The coefficients of the Hermite cubic spline. If `ts` has length $T$ then the
        coefficients will be of length $T - 1$, covering each of the intervals from `ts[0]`
        to `ts[1]`, and `ts[1]` to `ts[2]` etc.
        """
    
        ts = _check_ts(ts)
        fn = ft.partial(_backward_hermite_coefficients, fill_forward_nans_at_end, ts)
    
        # if len(jtu.tree_leaves(ys)) == 0:
        #     return ({}, {}, {}, {})
    
        if deriv0 is None:
            if replace_nans_at_start is None:
                coeffs = jtu.tree_map(fn, ys)
            else:
                _fn = lambda ys, replace_nans_at_start: fn(ys, None, replace_nans_at_start)
                coeffs = jtu.tree_map(_fn, ys, replace_nans_at_start)
        else:
            if replace_nans_at_start is None:
                coeffs = jtu.tree_map(fn, ys, deriv0)
            else:
                coeffs = jtu.tree_map(fn, ys, deriv0, replace_nans_at_start)
        ys_treedef = jtu.tree_structure(ys)
        coeffs_treedef = jtu.tree_structure((0, 0, 0, 0))
>       return jtu.tree_transpose(ys_treedef, coeffs_treedef, coeffs)
E       ValueError: Too few leaves for PyTreeDef; expected 4, got 0

diffrax/_global_interpolation.py:778: ValueError

@patrick-kidger patrick-kidger merged commit 8aafefc into patrick-kidger:main Jan 29, 2024
@patrick-kidger
Copy link
Owner

Wonderful. LGTM, thank you for the fix!

@allen-adastra allen-adastra deleted the empty_trees_for_interp branch February 1, 2024 18:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants