
Search with polars #767

Open
charles-turner-1 wants to merge 20 commits into main from search-with-polars

Conversation


@charles-turner-1 charles-turner-1 commented Jan 6, 2026

Change Summary

  • Add a _search.pl_search function that uses polars (as lazily as I can manage) rather than pandas when we don't have a pd.DataFrame instantiated. If we do have one, we use the old _search.search function.
  • esm_datastore.search() still triggers the creation of a pandas dataframe, but only well after the search itself has run.
  • Remove the pl_df attribute from FramesModel - it just adds memory overhead and I don't think we're benefiting from keeping it around.
  • Update all tests to ensure that the new search gives the same results as the old search.
  • Explicitly raise a NotImplementedError if users try to regex search over columns_with_iterables (this previously wasn't implemented but silently failed to raise, see Regular expressions when columns_with_iterables #679). I've put some code beyond the error which I think should get that working, but splitting it out into a separate PR makes more sense.

@aulemahal are you able to rerun the profiling code you posted in #753 against this branch? I've done some profiling myself (see below), but the numbers I'm getting out are surprisingly different. With that said, it looks like the memory overhead of the search is down substantially from where we were. It also looks like we're not realising the full dataframe into memory, so we should be down on previous memory usage (see the allocations on line 17).

Things I haven't changed

  • search_apply_require_all_on still needs a pandas dataframe, so search still triggers the creation of one. With some more work we could probably avoid this.
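For context, require_all_on keeps only the groups that satisfy every queried value. A simplified pandas sketch of that semantics (illustrative only, not the actual search_apply_require_all_on implementation):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "source": ["m1", "m1", "m2"],
        "experiment": ["historical", "ssp585", "historical"],
    }
)

# Keep only sources for which *all* requested experiments are present.
wanted = {"historical", "ssp585"}
keep = df.groupby("source")["experiment"].apply(lambda s: wanted <= set(s))
result = df[df["source"].isin(keep[keep].index)]
print(sorted(result["source"].unique()))  # ['m1'] - m2 lacks ssp585
```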

Memory Profiling

from memory_profiler import profile

from intake_esm.core import esm_datastore


@profile
def main():
    _ = 1
    cat = esm_datastore('/Users/u1166368/scratch/simulation.json')
    print("\nSearching for variable='tasmin'...\n")
    scat = cat.search(variable='tasmin')

    print('\n\nscat.df info:')
    print(scat.df.info())

    print('\n\ncat.df info:')
    print(cat.df.info())

    _srcs = scat.unique()


if __name__ == '__main__':
    main()

Gives (at the head of this branch):

python profile_intake_esm_pascal.py        2548ms  Thu Jan  8 16:12:04 2026

Searching for variable='tasmin'...

Filename: /Users/u1166368/catalog/intake-esm/intake_esm/_search.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    77    296.8 MiB    296.8 MiB           1   @profile
    78                                         def pl_search(
    79                                             *,
    80                                             lf: pl.LazyFrame,
    81                                             query: dict[str, typing.Any],
    82                                             columns_with_iterables: set,
    83                                             iterable_dtypes: dict[str, type] | None = None,
    84                                         ) -> pd.DataFrame:
    85                                             """
    86                                             Search for entries in the catalog.
    87                                         
    88                                             Parameters
    89                                             ----------
    90                                             df: :py:class:`~pandas.DataFrame`
    91                                                 A dataframe to search
    92                                             query: dict
    93                                                 A dictionary of query parameters to execute against the dataframe
    94                                             columns_with_iterables: list
    95                                                 Columns in the dataframe that have iterables
    96                                             iterable_dtypes: dict, optional
    97                                                 A dictionary mapping column names to their iterable dtypes. If not provided,
    98                                                 defaults to all tuple
    99                                         
   100                                             Returns
   101                                             -------
   102                                             dataframe: :py:class:`~pandas.DataFrame`
   103                                                 A new dataframe with the entries satisfying the query criteria.
   104                                         
   105                                             """
   106                                         
   107    296.8 MiB      0.0 MiB           1       if not query:
   108                                                 return lf.filter(pl.lit(False)).collect().to_pandas()
   109                                         
   110    336.0 MiB     39.2 MiB           1       full_schema = lf.head(1).collect().schema
   111                                         
   112    336.2 MiB      0.1 MiB           2       lf = lf.with_columns(
   113    336.2 MiB      0.1 MiB          24           [pl.col(colname).cast(full_schema[colname]) for colname in full_schema.keys()]
   114                                             )
   115                                         
   116    336.2 MiB      0.0 MiB           1       if isinstance(columns_with_iterables, str):
   117                                                 columns_with_iterables = [columns_with_iterables]
   118                                         
   119    336.2 MiB      0.0 MiB           1       iterable_dtypes = iterable_dtypes or {colname: tuple for colname in columns_with_iterables}
   120                                         
   121    336.2 MiB      0.0 MiB           1       for colname, dtype in iterable_dtypes.items():
   122                                                 if dtype == np.ndarray:
   123                                                     iterable_dtypes[colname] = tuple
   124                                         
   125    336.2 MiB      0.0 MiB           2       query_non_iterable = {
   126    336.2 MiB      0.0 MiB           3           key: val for key, val in query.items() if key not in columns_with_iterables
   127                                             }
   128                                         
   129    337.2 MiB      0.0 MiB           2       for colname, subquery in query_non_iterable.items():
   130    336.2 MiB      0.0 MiB           2           subquery = [None if pd.isna(subq) else subq for subq in subquery]
   131    336.2 MiB      0.0 MiB           1           if is_pattern(subquery):
   132                                                     case_insensitive = [
   133                                                         bool(q.flags & re.IGNORECASE) if isinstance(q, re.Pattern) else False
   134                                                         for q in subquery
   135                                                     ]
   136                                                     # Prepend (?i) to patterns for case insensitive matching if needed
   137                                                     subquery = [q.pattern if isinstance(q, re.Pattern) else q for q in subquery]
   138                                                     subquery = [f'(?i){q}' if ci else q for q, ci in zip(subquery, case_insensitive)]
   139                                         
   140                                                     lf = lf.filter(pl.col(colname).str.contains('|'.join(subquery), literal=False))
   141                                                 else:
   142    337.2 MiB      1.0 MiB           1               lf = lf.filter(pl.col(colname).is_in(subquery, nulls_equal=True))
   143                                         
   144    337.2 MiB      0.0 MiB           2       query_iterable = {key: val for key, val in query.items() if key in columns_with_iterables}
   145    337.2 MiB      0.0 MiB           1       for colname, subquery in query_iterable.items():
   146                                                 if is_pattern(subquery):
   147                                                     raise NotImplementedError(
   148                                                         'Pattern matching within iterable columns is not implemented yet.'
   149                                                     )
   150                                                     case_insensitive = [
   151                                                         bool(q.flags & re.IGNORECASE) if isinstance(q, re.Pattern) else False
   152                                                         for q in subquery
   153                                                     ]
   154                                                     # Prepend (?i) to patterns for case insensitive matching if needed
   155                                                     subquery = [q.pattern if isinstance(q, re.Pattern) else q for q in subquery]
   156                                                     subquery = [f'(?i){q}' if ci else q for q, ci in zip(subquery, case_insensitive)]
   157                                         
   158                                                     lf = lf.filter(
   159                                                         pl.col(colname)
   160                                                         .list.eval(pl.element().str.contains('|'.join(subquery), literal=False))
   161                                                         .any()
   162                                                     )
   163                                                 else:
   164                                                     lf = lf.filter(
   165                                                         pl.col(colname)
   166                                                         .list.eval(pl.element().is_in(subquery, nulls_equal=True).any())
   167                                                         .explode()
   168                                                     )
   169                                         
   170    475.0 MiB    137.8 MiB           1       df = lf.collect().to_pandas()
   171                                         
   172    475.0 MiB      0.0 MiB           1       for colname, dtype in iterable_dtypes.items():
   173                                                 df[colname] = df[colname].apply(dtype)
   174                                         
   175    475.0 MiB      0.0 MiB           1       return df


Filename: /Users/u1166368/catalog/intake-esm/intake_esm/cat.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   409    296.7 MiB    296.7 MiB           1       @profile
   410                                             def search(
   411                                                 self,
   412                                                 *,
   413                                                 query: QueryModel | dict[str, typing.Any],
   414                                                 require_all_on: str | list[str] | None = None,
   415                                             ) -> pd.DataFrame:
   416                                                 """
   417                                                 Search for entries in the catalog.
   418                                         
   419                                                 Parameters
   420                                                 ----------
   421                                                 query: dict, optional
   422                                                     A dictionary of query parameters to execute against the dataframe.
   423                                                 require_all_on : list, str, optional
   424                                                     A dataframe column or a list of dataframe columns across
   425                                                     which all entries must satisfy the query criteria.
   426                                                     If None, return entries that fulfill any of the criteria specified
   427                                                     in the query, by default None.
   428                                         
   429                                                 Returns
   430                                                 -------
   431                                                 catalog: ESMCatalogModel
   432                                                     A new catalog with the entries satisfying the query criteria.
   433                                         
   434                                                 """
   435                                         
   436                                                 # The way we get columns with iterables here looks a bit roundabout, but it
   437                                                 # minimizes memory overhead.
   438    296.8 MiB      0.1 MiB           1           cols = list(self.lf.collect_schema().keys())
   439    296.8 MiB      0.0 MiB          24           col_subset = {col for col, dtype in self.lf.collect_schema().items() if dtype == pl.Unknown}
   440                                         
   441    296.8 MiB      0.0 MiB           2           columns_with_iterables = {
   442                                                     col
   443    296.8 MiB      0.0 MiB           2               for col, dtype in self._frames.lf.head(1).select(col_subset).collect().schema.items()
   444                                                     if dtype == pl.List
   445                                                 }
   446                                         
   447    296.8 MiB      0.0 MiB           1           _query = (
   448                                                     query
   449    296.8 MiB      0.0 MiB           1               if isinstance(query, QueryModel)
   450    296.8 MiB      0.0 MiB           1               else QueryModel(query=query, require_all_on=require_all_on, columns=cols)
   451                                                 )
   452                                         
   453    296.8 MiB      0.0 MiB           1           if (_df := self._frames.df) is not None:
   454                                                     iterable_dtypes = {
   455                                                         colname: type(_df[colname].iloc[0]) for colname in columns_with_iterables
   456                                                     }
   457                                                 else:
   458    296.8 MiB      0.0 MiB           1               iterable_dtypes = None
   459                                         
   460    475.0 MiB    178.2 MiB           2           results = pl_search(
   461    296.8 MiB      0.0 MiB           1               lf=self.lf,
   462    296.8 MiB      0.0 MiB           1               query=_query.query,
   463    296.8 MiB      0.0 MiB           1               columns_with_iterables=columns_with_iterables,
   464    296.8 MiB      0.0 MiB           1               iterable_dtypes=iterable_dtypes,
   465                                                 )
   466                                         
   467    475.0 MiB      0.0 MiB           1           if _query.require_all_on is not None and not results.empty:
   468                                                     results = search_apply_require_all_on(
   469                                                         df=results,
   470                                                         query=_query.query,
   471                                                         require_all_on=_query.require_all_on,
   472                                                         columns_with_iterables=columns_with_iterables,
   473                                                     )
   474    475.0 MiB      0.0 MiB           1           return results




scat.df info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25600 entries, 0 to 25599
Data columns (total 23 columns):
 #   Column                   Non-Null Count  Dtype 
---  ------                   --------------  ----- 
 0   id                       25600 non-null  object
 1   type                     25600 non-null  object
 2   processing_level         25600 non-null  object
 3   bias_adjust_institution  388 non-null    object
 4   bias_adjust_project      473 non-null    object
 5   bias_adjust_reference    0 non-null      object
 6   mip_era                  25600 non-null  object
 7   activity                 25600 non-null  object
 8   driving_model            12760 non-null  object
 9   driving_member           12760 non-null  object
 10  institution              25600 non-null  object
 11  source                   25600 non-null  object
 12  experiment               25600 non-null  object
 13  member                   13237 non-null  object
 14  xrfreq                   25600 non-null  object
 15  frequency                25600 non-null  object
 16  variable                 25600 non-null  object
 17  domain                   25600 non-null  object
 18  date_start               25600 non-null  object
 19  date_end                 25600 non-null  object
 20  version                  7240 non-null   object
 21  format                   25600 non-null  object
 22  path                     25600 non-null  object
dtypes: object(23)
memory usage: 4.5+ MB
None


cat.df info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 159277 entries, 0 to 159276
Data columns (total 23 columns):
 #   Column                   Non-Null Count   Dtype                
---  ------                   --------------   -----                
 0   id                       159277 non-null  large_string[pyarrow]
 1   type                     159277 non-null  large_string[pyarrow]
 2   processing_level         159277 non-null  large_string[pyarrow]
 3   bias_adjust_institution  1354 non-null    large_string[pyarrow]
 4   bias_adjust_project      1893 non-null    large_string[pyarrow]
 5   bias_adjust_reference    0 non-null       large_string[pyarrow]
 6   mip_era                  159277 non-null  large_string[pyarrow]
 7   activity                 159277 non-null  large_string[pyarrow]
 8   driving_model            109569 non-null  large_string[pyarrow]
 9   driving_member           109569 non-null  large_string[pyarrow]
 10  institution              159277 non-null  large_string[pyarrow]
 11  source                   159277 non-null  large_string[pyarrow]
 12  experiment               159277 non-null  large_string[pyarrow]
 13  member                   52497 non-null   large_string[pyarrow]
 14  xrfreq                   159277 non-null  large_string[pyarrow]
 15  frequency                159277 non-null  large_string[pyarrow]
 16  variable                 159277 non-null  large_string[pyarrow]
 17  domain                   159277 non-null  large_string[pyarrow]
 18  date_start               158367 non-null  large_string[pyarrow]
 19  date_end                 158367 non-null  large_string[pyarrow]
 20  version                  86294 non-null   large_string[pyarrow]
 21  format                   159277 non-null  large_string[pyarrow]
 22  path                     159277 non-null  large_string[pyarrow]
dtypes: large_string[pyarrow](23)
memory usage: 84.2 MB
None
Filename: /Users/u1166368/catalog/intake-esm/profile_intake_esm_pascal.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
     6    230.7 MiB    230.7 MiB           1   @profile
     7                                         def main():
     8    230.7 MiB      0.0 MiB           1       _ = 1
     9    296.7 MiB     66.0 MiB           1       cat = esm_datastore('/Users/u1166368/scratch/simulation.json')
    10    296.7 MiB      0.0 MiB           1       print("\nSearching for variable='tasmin'...\n")
    11    481.2 MiB    184.5 MiB           1       scat = cat.search(variable='tasmin')
    12                                         
    13    481.2 MiB      0.0 MiB           1       print('\n\nscat.df info:')
    14    482.7 MiB      1.5 MiB           1       print(scat.df.info())
    15                                         
    16    482.7 MiB      0.0 MiB           1       print('\n\ncat.df info:')
    17    689.1 MiB    206.5 MiB           1       print(cat.df.info())
    18                                         
    19    727.6 MiB     38.5 MiB           1       _srcs = scat.unique()

With intake-esm=2025.2.3:

Filename: /Users/u1166368/scratch/pascal_profile/profile_intake_esm_pascal.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
     6    186.6 MiB    186.6 MiB           1   @profile
     7                                         def main():
     8    186.6 MiB      0.0 MiB           1       _ = 1
     9    397.8 MiB    211.2 MiB           1       cat = esm_datastore('/Users/u1166368/scratch/simulation.json')
    10    397.8 MiB      0.0 MiB           1       print("\nSearching for variable='tasmin'...\n")
    11    403.5 MiB      5.6 MiB           1       scat = cat.search(variable='tasmin')
    12                                         
    13    403.5 MiB      0.0 MiB           1       print('\n\nscat.df info:')
    14    404.2 MiB      0.8 MiB           1       print(scat.df.info())
    15                                         
    16    404.2 MiB      0.0 MiB           1       print('\n\ncat.df info:')
    17    404.3 MiB      0.1 MiB           1       print(cat.df.info())
    18                                         
    19    404.4 MiB      0.1 MiB           1       _srcs = scat.unique()

It looks like there might be some other memory management issues we'll need to clean up too, but presently I think that the initialisation & search memory usage should now be back in the same ballpark.


Search performance

from intake_esm.core import esm_datastore
import timeit


def main():
    _ = 1
    cat = esm_datastore("/Users/u1166368/scratch/simulation.json")
    print("\nSearching for variable='tasmin'...\n")

    elapsed = timeit.timeit(lambda: cat.search(variable="tasmin"), number=10)

    print(f"Average time per search: {elapsed / 10:.4f} seconds")


if __name__ == "__main__":
    main()
  • v2025.2.3 : Average time per search: 0.0060 seconds
  • This branch: Average time per search: 0.0340 seconds

This is obviously quite a bit slower (~5x), which is not great. Rolling together opening and searching the datastore:

from intake_esm.core import esm_datastore
import timeit


def m():
    cat = esm_datastore("/Users/u1166368/scratch/simulation.json")
    cat.search(variable="tasmin")


def main():
    elapsed = timeit.timeit(lambda: m(), number=10)
    print(f"Average time per open and search: {elapsed / 10:.4f} seconds")


if __name__ == "__main__":
    main()
  • This branch: Average time per open and search: 0.0835 seconds
  • v2025.2.3: Average time per open and search: 0.4668 seconds

I suspect most of that search time increase is the instantiation of the pandas dataframe - I'm hoping to continue to defer this further, but this memory issue needs fixing first.

Related issue number

#711
#753

Checklist

  • Unit tests for the changes exist
  • Tests pass on CI
  • Documentation reflects the changes where applicable

Copilot AI and others added 9 commits January 6, 2026 13:28
* Initial plan

* Add memory_profiler and pympler to CI environment dependencies

Co-authored-by: charles-turner-1 <52199577+charles-turner-1@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: charles-turner-1 <52199577+charles-turner-1@users.noreply.github.com>
@charles-turner-1 charles-turner-1 marked this pull request as ready for review January 8, 2026 08:40
@aulemahal (Contributor)

I ran the same thing as last time, which looks like the same as what you used. However, you seem to get lower RAM increments, and I can't reproduce the same values as last time. I guess I did something wrong, but I can't see what.

Anyway, here's the profile with this branch:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    34    232.0 MiB    232.0 MiB           1   @profile
    35                                         def func(kwargs):
    36    451.8 MiB    219.9 MiB           1   	cat = intake.open_esm_datastore('simulation.json', read_csv_kwargs=kwargs)
    37    743.9 MiB    292.1 MiB           1   	scat = cat.search(variable='tasmin')
    38    785.0 MiB     41.1 MiB           1   	srcs = scat.unique()
    39    785.0 MiB      0.0 MiB           1   	print(srcs.id)

And with main:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    34    223.4 MiB    223.4 MiB           1   @profile
    35                                         def func(kwargs):
    36    383.9 MiB    160.6 MiB           1   	cat = intake.open_esm_datastore('simulation.json', read_csv_kwargs=kwargs)
    37    889.1 MiB    505.2 MiB           1   	scat = cat.search(variable='tasmin')
    38    913.9 MiB     24.8 MiB           1   	srcs = scat.unique()
    39    913.9 MiB      0.0 MiB           1   	print(srcs.id)

The values are not really stable - I see them moving around when I repeat the test. With this uncertainty in mind, it seems that this branch moves the increment to the search, but the overall RAM usage is similar? Maybe a slight decrease.

And with v2025.2.3 (keeping the same env!):

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    34    228.0 MiB    228.0 MiB           1   @profile
    35                                         def func(kwargs):
    36    341.7 MiB    113.7 MiB           1   	cat = intake.open_esm_datastore('simulation.json', read_csv_kwargs=kwargs)
    37    341.7 MiB      0.0 MiB           1   	scat = cat.search(variable='tasmin')
    38    341.7 MiB      0.0 MiB           1   	srcs = scat.unique()
    39    341.7 MiB      0.0 MiB           1   	print(srcs.id)

Thanks for your work! Sorry that I can't help much right now.

@charles-turner-1 (Collaborator, Author)

Weird - those numbers are very different to mine... I'll keep digging!

@charles-turner-1 charles-turner-1 marked this pull request as draft January 8, 2026 23:33
@charles-turner-1 (Collaborator, Author)

Okay, with the latest commits:

  • Average time per search: 0.0548 seconds
  • Average time per open and search: 0.0723 seconds

Newest changes TLDR;

  • If we don't have a pandas dataframe instantiated, we do the searching with polars, and then instantiate it when asked for
  • If we do have a pandas dataframe, we do the searching using it - the assumption being that instantiating the pandas dataframe will typically be much more expensive than the search.

Memory usage is basically unchanged by this.


charles-turner-1 commented Jan 9, 2026

Following merging #771 into this, total memory usage is now down to 200MiB when using polars=1.34:

python profile_intake_esm_pascal.py

Searching for variable='tasmin'...

Filename: /Users/u1166368/catalog/intake-esm/profile_intake_esm_pascal.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
     6    231.1 MiB    231.1 MiB           1   @profile
     7                                         def main():
     8    231.1 MiB      0.0 MiB           1       _ = 1
     9    237.9 MiB      6.8 MiB           1       cat = esm_datastore('/Users/u1166368/scratch/simulation.json')
    10    237.9 MiB      0.0 MiB           1       print("\nSearching for variable='tasmin'...\n")
    11    444.6 MiB    206.7 MiB           1       scat = cat.search(variable='tasmin')
    12
    13                                             # print('\n\nscat.df info:')
    14                                             # print(scat.df.info())
    15
    16                                             # print('\n\ncat.df info:')
    17                                             # print(cat.df.info())
    18
    19                                             # _srcs = scat.unique()

A tad frustrating that it doesn't seem like we'll be able to get down much further without aggressively deferring the creation of the pandas dataframe, but I think this is probably acceptable?

Average time per open and search: 0.0846 seconds
Average time per search: 0.0616 seconds

@charles-turner-1 charles-turner-1 marked this pull request as ready for review January 12, 2026 01:43
@charles-turner-1 (Collaborator, Author)

I've started looking at transforming the esmcat.search functionality to polars, and it's looking rather complicated to do cleanly. I think leaving that for a separate PR is going to be the way to go, unfortunately (I'd love to do it all in one hit 😅).

@aulemahal does this look like the memory consumption is down enough to work well for you guys now?
