-
-
Notifications
You must be signed in to change notification settings - Fork 18.6k
BUG: support corr and cov functions for custom BaseIndexer rolling windows #33804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
808b375
3a08839
8281c57
5283c71
35f0295
d0b2ca4
e24c64f
708cfe6
55d6f7d
1fc9e92
0d0c028
4bf564a
771c2bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,7 +48,6 @@ | |
calculate_center_offset, | ||
calculate_min_periods, | ||
get_weighted_roll_func, | ||
validate_baseindexer_support, | ||
zsqrt, | ||
) | ||
from pandas.core.window.indexers import ( | ||
|
@@ -393,12 +392,11 @@ def _get_cython_func_type(self, func: str) -> Callable: | |
return self._get_roll_func(f"{func}_variable") | ||
return partial(self._get_roll_func(f"{func}_fixed"), win=self._get_window()) | ||
|
||
def _get_window_indexer(self, window: int, func_name: Optional[str]) -> BaseIndexer: | ||
def _get_window_indexer(self, window: int) -> BaseIndexer: | ||
""" | ||
Return an indexer class that will compute the window start and end bounds | ||
""" | ||
if isinstance(self.window, BaseIndexer): | ||
validate_baseindexer_support(func_name) | ||
return self.window | ||
if self.is_freq_type: | ||
return VariableWindowIndexer(index_array=self._on.asi8, window_size=window) | ||
|
@@ -444,7 +442,7 @@ def _apply( | |
|
||
blocks, obj = self._create_blocks() | ||
block_list = list(blocks) | ||
window_indexer = self._get_window_indexer(window, name) | ||
window_indexer = self._get_window_indexer(window) | ||
|
||
results = [] | ||
exclude: List[Scalar] = [] | ||
|
@@ -1632,20 +1630,23 @@ def quantile(self, quantile, interpolation="linear", **kwargs): | |
""" | ||
|
||
def cov(self, other=None, pairwise=None, ddof=1, **kwargs): | ||
if isinstance(self.window, BaseIndexer): | ||
validate_baseindexer_support("cov") | ||
|
||
if other is None: | ||
other = self._selected_obj | ||
# only default unset | ||
pairwise = True if pairwise is None else pairwise | ||
other = self._shallow_copy(other) | ||
|
||
# GH 16058: offset window | ||
if self.is_freq_type: | ||
window = self.win_freq | ||
# GH 32865. We leverage rolling.mean, so we pass | ||
# to the rolling constructors the data used when constructing self: | ||
# window width, frequency data, or a BaseIndexer subclass | ||
if isinstance(self.window, BaseIndexer): | ||
window = self.window | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix for |
||
else: | ||
window = self._get_window(other) | ||
# GH 16058: offset window | ||
if self.is_freq_type: | ||
window = self.win_freq | ||
else: | ||
window = self._get_window(other) | ||
|
||
def _get_cov(X, Y): | ||
# GH #12373 : rolling functions error on float32 data | ||
|
@@ -1778,15 +1779,19 @@ def _get_cov(X, Y): | |
) | ||
|
||
def corr(self, other=None, pairwise=None, **kwargs): | ||
if isinstance(self.window, BaseIndexer): | ||
validate_baseindexer_support("corr") | ||
|
||
if other is None: | ||
other = self._selected_obj | ||
# only default unset | ||
pairwise = True if pairwise is None else pairwise | ||
other = self._shallow_copy(other) | ||
window = self._get_window(other) if not self.is_freq_type else self.win_freq | ||
|
||
# GH 32865. We leverage rolling.cov and rolling.std here, so we pass | ||
# to the rolling constructors the data used when constructing self: | ||
# window width, frequency data, or a BaseIndexer subclass | ||
if isinstance(self.window, BaseIndexer): | ||
mroeschke marked this conversation as resolved.
Show resolved
Hide resolved
|
||
window = self.window | ||
else: | ||
window = self._get_window(other) if not self.is_freq_type else self.win_freq | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add a comment here and on cov explaining what e are doing (generally not specific to BaseIndexer) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jreback Done. Please take a look if it's what you had in mind. |
||
|
||
def _get_corr(a, b): | ||
a = a.rolling( | ||
|
Uh oh!
There was an error while loading. Please reload this page.