@@ -3088,6 +3088,88 @@ namespace xt
30883088 return f;
30893089 }
30903090
3091+ namespace detail
3092+ {
3093+ template <class E1 , class E2 >
3094+ auto calculate_discontinuity (E1 && discontinuity, E2 &&)
3095+ {
3096+ return discontinuity;
3097+ }
3098+
3099+ template <class E2 >
3100+ auto calculate_discontinuity (xt::placeholders::xtuph, E2 && period)
3101+ {
3102+ return 0.5 * period;
3103+ }
3104+
3105+ template <class E1 , class E2 >
3106+ auto
3107+ calculate_interval (E2 && period, typename std::enable_if<std::is_integral<E1 >::value, E1 >::type* = 0 )
3108+ {
3109+ auto interval_high = 0.5 * period;
3110+ uint64_t remainder = static_cast <uint64_t >(period) % 2 ;
3111+ auto boundary_ambiguous = (remainder == 0 );
3112+ return std::make_tuple (interval_high, boundary_ambiguous);
3113+ }
3114+
3115+ template <class E1 , class E2 >
3116+ auto
3117+ calculate_interval (E2 && period, typename std::enable_if<std::is_floating_point<E1 >::value, E1 >::type* = 0 )
3118+ {
3119+ auto interval_high = 0.5 * period;
3120+ auto boundary_ambiguous = true ;
3121+ return std::make_tuple (interval_high, boundary_ambiguous);
3122+ }
3123+ }
3124+
3125+ /* *
3126+ * @ingroup basic_functions
3127+ * @brief Unwrap by taking the complement of large deltas with respect to the period
3128+ * @details https://numpy.org/doc/stable/reference/generated/numpy.unwrap.html
3129+ * @param p Input array.
3130+ * @param discontinuity
3131+ * Maximum discontinuity between values, default is `period / 2`.
3132+ * Values below `period / 2` are treated as if they were `period / 2`.
3133+ * To have an effect different from the default, use `discontinuity > period / 2`.
3134+ * @param axis Axis along which unwrap will operate, default: the last axis.
3135+ * @param period Size of the range over which the input wraps. Default: \f$ 2 \pi \f$.
3136+ */
3137+
3138+ template <class E1 , class E2 = xt::placeholders::xtuph, class E3 = double >
3139+ inline auto unwrap (
3140+ E1 && p,
3141+ E2 discontinuity = xnone(),
3142+ std::ptrdiff_t axis = -1,
3143+ E3 period = 2.0 * xt::numeric_constants<double>::PI
3144+ )
3145+ {
3146+ auto discont = detail::calculate_discontinuity (discontinuity, period);
3147+ using value_type = typename std::decay_t <E1 >::value_type;
3148+ std::size_t saxis = normalize_axis (p.dimension (), axis);
3149+ auto dd = diff (p, 1 , axis);
3150+ xstrided_slice_vector slice (p.dimension (), all ());
3151+ slice[saxis] = range (1 , xnone ());
3152+ auto interval_tuple = detail::calculate_interval<value_type>(period);
3153+ auto interval_high = std::get<0 >(interval_tuple);
3154+ auto boundary_ambiguous = std::get<1 >(interval_tuple);
3155+ auto interval_low = -interval_high;
3156+ auto ddmod = xt::eval (xt::fmod (xt::fmod (dd - interval_low, period) + period, period) + interval_low);
3157+ if (boundary_ambiguous)
3158+ {
3159+ // for `mask = (abs(dd) == period/2)`, the above line made
3160+ // `ddmod[mask] == -period/2`. correct these such that
3161+ // `ddmod[mask] == sign(dd[mask])*period/2`.
3162+ auto boolmap = xt::equal (ddmod, interval_low) && (xt::greater (dd, 0.0 ));
3163+ ddmod = xt::where (boolmap, interval_high, ddmod);
3164+ }
3165+ auto ph_correct = xt::eval (ddmod - dd);
3166+ ph_correct = xt::where (xt::abs (dd) < discont, 0.0 , ph_correct);
3167+ E1 up (p);
3168+ strided_view (up, slice) = strided_view (p, slice)
3169+ + xt::cumsum (ph_correct, static_cast <std::ptrdiff_t >(saxis));
3170+ return up;
3171+ }
3172+
30913173 /* *
30923174 * @ingroup basic_functions
30933175 * @brief Returns the one-dimensional piecewise linear interpolant to a function with given discrete data
0 commit comments