Skip to content

Commit 0e6b11a

Browse files
authored
Added unwrap (#2710)
1 parent 4c71ce3 commit 0e6b11a

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

include/xtensor/xmath.hpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3088,6 +3088,88 @@ namespace xt
30883088
return f;
30893089
}
30903090

3091+
namespace detail
3092+
{
3093+
template <class E1, class E2>
3094+
auto calculate_discontinuity(E1&& discontinuity, E2&&)
3095+
{
3096+
return discontinuity;
3097+
}
3098+
3099+
template <class E2>
3100+
auto calculate_discontinuity(xt::placeholders::xtuph, E2&& period)
3101+
{
3102+
return 0.5 * period;
3103+
}
3104+
3105+
template <class E1, class E2>
3106+
auto
3107+
calculate_interval(E2&& period, typename std::enable_if<std::is_integral<E1>::value, E1>::type* = 0)
3108+
{
3109+
auto interval_high = 0.5 * period;
3110+
uint64_t remainder = static_cast<uint64_t>(period) % 2;
3111+
auto boundary_ambiguous = (remainder == 0);
3112+
return std::make_tuple(interval_high, boundary_ambiguous);
3113+
}
3114+
3115+
template <class E1, class E2>
3116+
auto
3117+
calculate_interval(E2&& period, typename std::enable_if<std::is_floating_point<E1>::value, E1>::type* = 0)
3118+
{
3119+
auto interval_high = 0.5 * period;
3120+
auto boundary_ambiguous = true;
3121+
return std::make_tuple(interval_high, boundary_ambiguous);
3122+
}
3123+
}
3124+
3125+
/**
3126+
* @ingroup basic_functions
3127+
* @brief Unwrap by taking the complement of large deltas with respect to the period
3128+
* @details https://numpy.org/doc/stable/reference/generated/numpy.unwrap.html
3129+
* @param p Input array.
3130+
* @param discontinuity
3131+
* Maximum discontinuity between values, default is `period / 2`.
3132+
* Values below `period / 2` are treated as if they were `period / 2`.
3133+
* To have an effect different from the default, use `discontinuity > period / 2`.
3134+
* @param axis Axis along which unwrap will operate, default: the last axis.
3135+
* @param period Size of the range over which the input wraps. Default: \f$ 2 \pi \f$.
3136+
*/
3137+
3138+
template <class E1, class E2 = xt::placeholders::xtuph, class E3 = double>
3139+
inline auto unwrap(
3140+
E1&& p,
3141+
E2 discontinuity = xnone(),
3142+
std::ptrdiff_t axis = -1,
3143+
E3 period = 2.0 * xt::numeric_constants<double>::PI
3144+
)
3145+
{
3146+
auto discont = detail::calculate_discontinuity(discontinuity, period);
3147+
using value_type = typename std::decay_t<E1>::value_type;
3148+
std::size_t saxis = normalize_axis(p.dimension(), axis);
3149+
auto dd = diff(p, 1, axis);
3150+
xstrided_slice_vector slice(p.dimension(), all());
3151+
slice[saxis] = range(1, xnone());
3152+
auto interval_tuple = detail::calculate_interval<value_type>(period);
3153+
auto interval_high = std::get<0>(interval_tuple);
3154+
auto boundary_ambiguous = std::get<1>(interval_tuple);
3155+
auto interval_low = -interval_high;
3156+
auto ddmod = xt::eval(xt::fmod(xt::fmod(dd - interval_low, period) + period, period) + interval_low);
3157+
if (boundary_ambiguous)
3158+
{
3159+
// for `mask = (abs(dd) == period/2)`, the above line made
3160+
//`ddmod[mask] == -period/2`. correct these such that
3161+
//`ddmod[mask] == sign(dd[mask])*period/2`.
3162+
auto boolmap = xt::equal(ddmod, interval_low) && (xt::greater(dd, 0.0));
3163+
ddmod = xt::where(boolmap, interval_high, ddmod);
3164+
}
3165+
auto ph_correct = xt::eval(ddmod - dd);
3166+
ph_correct = xt::where(xt::abs(dd) < discont, 0.0, ph_correct);
3167+
E1 up(p);
3168+
strided_view(up, slice) = strided_view(p, slice)
3169+
+ xt::cumsum(ph_correct, static_cast<std::ptrdiff_t>(saxis));
3170+
return up;
3171+
}
3172+
30913173
/**
30923174
* @ingroup basic_functions
30933175
* @brief Returns the one-dimensional piecewise linear interpolant to a function with given discrete data

test/test_xmath.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,4 +923,48 @@ namespace xt
923923

924924
EXPECT_EQ(result, expected);
925925
}
926+
927+
TEST(xmath, unwrap)
928+
{
929+
{
930+
// {0, pi / 4, pi / 2, -pi / 4, 0}
931+
xt::xarray<double> expected = {0., 0.78539816, 1.57079633, -0.78539816, 0};
932+
auto pi = xt::numeric_constants<double>::PI;
933+
xt::xarray<double> phase = xt::linspace<double>(0, pi, 5);
934+
xt::view(phase, xt::range(3, xt::xnone())) += pi;
935+
auto unwrapped = xt::unwrap(phase);
936+
EXPECT_TRUE(xt::allclose(expected, unwrapped));
937+
}
938+
{
939+
xt::xarray<double> expected = {
940+
-180.,
941+
-140.,
942+
-100.,
943+
-60.,
944+
-20.,
945+
20.,
946+
60.,
947+
100.,
948+
140.,
949+
180.,
950+
220.,
951+
260.,
952+
300.,
953+
340.,
954+
380.,
955+
420.,
956+
460.,
957+
500.,
958+
540.};
959+
xt::xarray<double> phase_deg = xt::fmod(xt::linspace<double>(0, 720, 19), 360) - 180;
960+
auto unwrapped = xt::unwrap(phase_deg, xnone(), -1, 360.0);
961+
EXPECT_TRUE(xt::allclose(expected, unwrapped));
962+
}
963+
{
964+
xt::xarray<int> expected = {2, 3, 4, 5, 6, 7, 8, 9};
965+
xt::xarray<int> phase = {2, 3, 4, 5, 2, 3, 4, 5};
966+
auto unwrapped = xt::unwrap(phase, xnone(), -1, 4);
967+
EXPECT_TRUE(xt::allclose(expected, unwrapped));
968+
}
969+
}
926970
}

0 commit comments

Comments
 (0)