From 7b72ddc8900c752f4d4848afa3e821f69725e10f Mon Sep 17 00:00:00 2001 From: John Muradeli Date: Mon, 25 Nov 2024 12:19:52 +0400 Subject: [PATCH] fix negative stride (torch); fix inferring `nv` in `icwt` --- ssqueezepy/_cwt.py | 4 ++-- ssqueezepy/utils/backend.py | 5 +++++ ssqueezepy/utils/cwt_utils.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ssqueezepy/_cwt.py b/ssqueezepy/_cwt.py index 4339b9d..6b54521 100644 --- a/ssqueezepy/_cwt.py +++ b/ssqueezepy/_cwt.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np from .utils import fft, ifft, ifftshift, FFT_GLOBAL -from .utils import WARN, adm_cwt, adm_ssq, _process_fs_and_t +from .utils import WARN, adm_cwt, adm_ssq, _process_fs_and_t, is_array_or_tensor from .utils import padsignal, process_scales, logscale_transition_idx from .utils import backend as S from .utils.backend import Q @@ -404,7 +404,7 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True, #### Prepare for inversion ############################################### *_, na, n = Wx.shape x_len = x_len or n - if not isinstance(scales, np.ndarray) and nv is None: + if not is_array_or_tensor(scales) and nv is None: nv = 32 # must match forward's; default to `cwt`'s wavelet = _process_gmw_wavelet(wavelet, l1_norm) diff --git a/ssqueezepy/utils/backend.py b/ssqueezepy/utils/backend.py index 64f1972..1a3f5ae 100644 --- a/ssqueezepy/utils/backend.py +++ b/ssqueezepy/utils/backend.py @@ -46,6 +46,11 @@ def is_tensor(*args, mode='all'): return cond(isinstance(x, torch.Tensor) for x in args) +def is_array_or_tensor(*args, mode='all'): + cond = all if mode == 'all' else any + return cond(isinstance(x, (torch.Tensor, np.ndarray)) for x in args) + + def is_dtype(x, str_dtype): return (str_dtype in str(x.dtype) if isinstance(str_dtype, str) else any(sd in str(x.dtype) for sd in str_dtype)) diff --git a/ssqueezepy/utils/cwt_utils.py b/ssqueezepy/utils/cwt_utils.py index 7e2dfdf..d005c0d 100644 --- a/ssqueezepy/utils/cwt_utils.py +++ b/ssqueezepy/utils/cwt_utils.py @@ -588,7 +588,7 @@ def integrate_analytic(int_fn, nowarn=False): Integrates near zero separately in log space (useful for e.g. 1/x). """ def _est_arr(mxlim, N): - t = np.linspace(mxlim, .1, N, endpoint=False)[::-1] + t = np.linspace(mxlim, .1, N, endpoint=False)[::-1].copy() arr = int_fn(t) max_idx = np.argmax(arr)