Skip to content

Commit

Permalink
Lazy load pandas in TFDS to speed up import time.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555472530
  • Loading branch information
marcenacp authored and The TensorFlow Datasets Authors committed Aug 10, 2023
1 parent 45a45ed commit 57ee730
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 59 deletions.
85 changes: 39 additions & 46 deletions tensorflow_datasets/core/as_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import dataclasses
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
from tensorflow_datasets.core import dataset_info
Expand All @@ -28,15 +28,9 @@
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core.utils import py_utils
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

try:
import pandas # pylint: disable=g-import-not-at-top

DataFrame = pandas.DataFrame
except ImportError:
DataFrame = object

# Should be `pandas.io.formats.style.Styler`, but importing it is costly.
Styler = Any

Expand Down Expand Up @@ -127,43 +121,6 @@ def _get_feature(
return feature, sequence_rank


class StyledDataFrame(DataFrame):
"""`pandas.DataFrame` displayed as `pandas.io.formats.style.Styler`.

`StyledDataFrame` is a `pandas.DataFrame` with a better Jupyter notebook
representation. Contrary to a regular `pandas.DataFrame`, the `style` is
attached to the `pandas.DataFrame` itself.

```
df = StyledDataFrame(...)
df.current_style.apply(...)  # Configure the style
df  # The data-frame is displayed using `pandas.io.formats.style.Styler`
```
"""

# StyledDataFrame could be improved so that the style is forwarded when
# selecting sub-data frames.

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) # pytype: disable=wrong-arg-count # re-none
# Use name-mangling for forward-compatibility in case pandas
# adds a `_styler` attribute in the future.
# The styler starts unset; it is only created by `current_style`.
self.__styler: Optional[Styler] = None

@property
def current_style(self) -> Styler:
"""Like `pandas.DataFrame.style`, but attach the style to the DataFrame.

The styler is created from the base-class `style` on first access and the
same object is returned on every later access.
"""
if self.__styler is None:
self.__styler = super().style # pytype: disable=attribute-error # re-none
return self.__styler

def _repr_html_(self) -> str:
# See base class for doc.
# Until `current_style` has been accessed there is no styler, so defer to
# the base-class HTML representation; otherwise render through the styler.
if self.__styler is None:
return super()._repr_html_() # pytype: disable=attribute-error # re-none
return self.__styler._repr_html_() # pylint: disable=protected-access


def _make_columns(
specs: TreeDict[tf.TypeSpec],
ds_info: Optional[dataset_info.DatasetInfo],
Expand All @@ -187,7 +144,7 @@ def _make_row_dict(
def as_dataframe(
ds: tf.data.Dataset,
ds_info: Optional[dataset_info.DatasetInfo] = None,
) -> StyledDataFrame:
) -> pd.DataFrame:
"""Convert the dataset into a pandas dataframe.
Warning: The dataframe will be loaded entirely in memory, you may
Expand All @@ -211,6 +168,42 @@ def as_dataframe(
# Raise a clean error message if pandas isn't installed.
lazy_imports_lib.lazy_imports.pandas # pylint: disable=pointless-statement

class StyledDataFrame(pd.DataFrame):
"""`pandas.DataFrame` displayed as `pandas.io.formats.style.Styler`.

`StyledDataFrame` is a `pandas.DataFrame` with a better Jupyter notebook
representation. Contrary to a regular `pandas.DataFrame`, the `style` is
attached to the `pandas.DataFrame` itself.

```
df = StyledDataFrame(...)
df.current_style.apply(...)  # Configure the style
df  # The data-frame is displayed using `pandas.io.formats.style.Styler`
```
"""

# StyledDataFrame could be improved so that the style is forwarded when
# selecting sub-data frames.

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) # pytype: disable=wrong-arg-count # re-none
# Use name-mangling for forward-compatibility in case pandas
# adds a `_styler` attribute in the future.
# The styler starts unset; it is only created by `current_style`.
self.__styler: Optional[Styler] = None

@property
def current_style(self) -> Styler:
"""Like `pandas.DataFrame.style`, but attach the style to the DataFrame.

The styler is created from the base-class `style` on first access and the
same object is returned on every later access.
"""
if self.__styler is None:
self.__styler = super().style # pytype: disable=attribute-error # re-none
return self.__styler

def _repr_html_(self) -> Union[None, str]:
# See base class for doc.
# Until `current_style` has been accessed there is no styler, so defer to
# the base-class HTML representation; otherwise render through the styler.
if self.__styler is None:
return super()._repr_html_() # pytype: disable=attribute-error # re-none
return self.__styler._repr_html_() # pylint: disable=protected-access

# Pack `as_supervised=True` datasets
if ds_info:
ds = dataset_info.pack_as_supervised_ds(ds, ds_info)
Expand Down
6 changes: 1 addition & 5 deletions tensorflow_datasets/core/utils/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@

from absl import logging
from tensorflow_datasets.core.utils import tqdm_utils

try:
import pandas as pd # pylint: disable=g-import-not-at-top
except ImportError:
pd = Any
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd

# pylint: disable=logging-format-interpolation

Expand Down
8 changes: 5 additions & 3 deletions tensorflow_datasets/core/utils/lazy_imports_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,16 @@ def array_record_error_callback(**kwargs):
print("***************************************************************\n\n")


with lazy_imports():
import pandas # pylint: disable=g-import-not-at-top,unused-import


with lazy_imports(
error_callback=tf_error_callback, success_callback=ensure_tf_version
):
import tensorflow as tf # pylint: disable=g-import-not-at-top,unused-import # pytype: disable=import-error
import tensorflow # pylint: disable=g-import-not-at-top,unused-import # pytype: disable=import-error


with lazy_imports(error_callback=array_record_error_callback):
from array_record.python import array_record_data_source # pylint: disable=g-import-not-at-top,unused-import
from array_record.python import array_record_module # pylint: disable=g-import-not-at-top,unused-import

tensorflow = tf
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
import collections
import itertools
import random
from typing import Iterable, Dict, Text, List
from typing import Dict, Iterable, List, Text

from absl import app
from absl import flags
from etils import epath
import numpy as np
import pandas as pd
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.utils import resource_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd

# Command-line arguments.
flags.DEFINE_string('save_path', None, 'Path to save generated data to.')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Cifar100N(tfds.core.GeneratorBasedBuilder):
```
import numpy as np
import pandas as pd
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
human_labels_np_path = '<local_path>/CIFAR-100_human_ordered.npy'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
from etils import epath
import numpy as np
import pandas as pd
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
import tensorflow_datasets.public_api as tfds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Cifar10N(tfds.core.GeneratorBasedBuilder):
```
import numpy as np
import pandas as pd
from tensorflow_datasets.core.utils.lazy_imports_utils import pandas as pd
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
human_labels_np_path = '<local_path>/CIFAR-10_human_ordered.npy'
Expand Down
1 change: 1 addition & 0 deletions tensorflow_datasets/import_without_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def _generate_examples(self):

self.assertNotIn('tensorflow', sys.modules)
self.assertNotIn('array_record', sys.modules)
self.assertNotIn('pandas', sys.modules)

data_dir = '/tmp/import_without_tf'
builder = DummyDataset(data_dir=data_dir)
Expand Down

0 comments on commit 57ee730

Please sign in to comment.