Commit

First round of updates (#130)
rusty1s authored Oct 21, 2023
1 parent 8146755 commit 4c5d3a9
Showing 19 changed files with 140 additions and 87 deletions.
4 changes: 1 addition & 3 deletions benchmark/data_frame_benchmark.py
@@ -66,9 +66,7 @@
 scale=args.scale, idx=args.idx)
 dataset.materialize()
 dataset = dataset.shuffle()
-train_dataset = dataset.get_split_dataset('train')
-val_dataset = dataset.get_split_dataset('val')
-test_dataset = dataset.get_split_dataset('test')
+train_dataset, val_dataset, test_dataset = dataset.split()

 train_tensor_frame = train_dataset.tensor_frame
 val_tensor_frame = val_dataset.tensor_frame
2 changes: 1 addition & 1 deletion docs/source/get_started/installation.rst
@@ -1,7 +1,7 @@
 Installation
 ============

-PyTorch Frame is available for Python 3.8 to Python 3.11 on Linux, Windows and macOS.
+:pyf:`PyTorch Frame` is available for :python:`Python 3.8` to :python:`Python 3.11` on Linux, Windows and macOS.

 Installation via PyPI
 ---------------------
4 changes: 2 additions & 2 deletions docs/source/get_started/introduction.rst
@@ -2,7 +2,7 @@ Introduction by Example
 =======================

 :pyf:`PyTorch Frame` is a tabular deep learning extension library for :pytorch:`null` `PyTorch <https://pytorch.org>`_.
-Modern data is stored in a table format with heterogeneous columns each with its own semantic type, e.g., numerical (e.g., age, price), categorical (e.g., gender, product type), time, text (e.g., descriptions, comments), images, etc.
+Modern data is stored in a table format with heterogeneous columns each with its own semantic type, *e.g.*, numerical (such as age or price), categorical (such as gender or product type), time, text (such as descriptions or comments), images, etc.
 The goal of :pyf:`PyTorch Frame` is to build a deep learning framework to perform effective machine learning on such complex and diverse data.

 Many recent tabular models follow the modular design of :obj:`FeatureEncoder`, :obj:`TableConv`, and :obj:`Decoder`.
@@ -100,7 +100,7 @@ will materialize the dataset and save the materialized :class:`~torch_frame.Tens
 .. note::
 Note that materialization does minimal processing of the original features, e.g., no normalization and missing value handling are performed.
-:pyf:`PyTorch Frame` converts missing values in categorical :class:`torch_frame.stype` to `-1` and missing values in numerical :class:`torch_frame.stype` to `NaN`.
+PyTorch Frame converts missing values in categorical :class:`torch_frame.stype` to `-1` and missing values in numerical :class:`torch_frame.stype` to `NaN`.
 We expect `NaN`/missing-value handling and normalization to be handled by the model side via :class:`torch_frame.nn.encoder.StypeEncoder`.
 The :class:`~torch_frame.TensorFrame` object has :class:`torch.Tensor` at its core; therefore, it's friendly for training and inference with PyTorch. In :pyf:`PyTorch Frame`, we build data loaders and models around :class:`TensorFrame`, benefitting from all the efficiency and flexibility from PyTorch.
6 changes: 3 additions & 3 deletions docs/source/get_started/modular_design.rst
@@ -55,7 +55,7 @@ and :class:`~torch_frame.nn.encoder.LinearEncoder` for encoding `stype.numerical
 )
 There are other encoders implemented as well such as :class:`~torch_frame.nn.encoder.LinearBucketEncoder` and :class:`~torch_frame.nn.encoder.ExcelFormerEncoder` for `stype.numerical` columns.
-See :ref:`_torch_frame_nn` for the full list of built-in encoders.
+See :py:mod:`torch_frame.nn` for the full list of built-in encoders.

 You can also implement your custom encoder for a given `stype` by inheriting :class:`~torch_frame.nn.encoder.StypeEncoder`.

@@ -103,7 +103,7 @@ Initializing and calling it is straightforward.
 conv = SelfAttentionConv(32)
 x = conv(x)
-See :ref:`_torch_frame_nn` for the full list of built-in convolution layers.
+See :py:mod:`torch_frame.nn` for the full list of built-in convolution layers.


 3. :class:`Decoder`
@@ -133,4 +133,4 @@ Below is a simple example of a :class:`~torch_frame.nn.decoder.Decoder` that mea
 # [batch_size, out_channels]
 return self.lin(out)
-See :ref:`_torch_frame_nn` for the full list of built-in decoders.
+See :py:mod:`torch_frame.nn` for the full list of built-in decoders.
3 changes: 0 additions & 3 deletions docs/source/modules/nn.rst
@@ -1,5 +1,3 @@
-.. _torch_frame_nn:
-
 torch_frame.nn
 ==============

@@ -32,7 +30,6 @@ torch_frame.nn.conv
 :toctree: ../generated
 :template: autosummary/class.rst

-
 {% for name in torch_frame.nn.conv.classes %}
 {{ name }}
 {% endfor %}
4 changes: 2 additions & 2 deletions docs/source/modules/transforms.rst
@@ -22,12 +22,12 @@ Let's look an example, where we apply `CatToNumTransform <https://dl.acm.org/doi
 dataset = Yandex(root='/tmp/adult', name='adult')
 dataset.materialize()
 transform = CatToNumTransform()
-train_dataset = dataset.get_split_dataset('train')
+train_dataset = dataset.get_split('train')
 train_dataset.tensor_frame.col_names_dict[stype.categorical]
 >>> ['C_feature_0', 'C_feature_1', 'C_feature_2', 'C_feature_3', 'C_feature_4', 'C_feature_5', 'C_feature_6', 'C_feature_7']
-test_dataset = dataset.get_split_dataset('test')
+test_dataset = dataset.get_split('test')
 transform.fit(train_dataset.tensor_frame, dataset.col_stats)
 transformed_col_stats = transform.transformed_stats
4 changes: 1 addition & 3 deletions examples/excelformer.py
@@ -43,9 +43,7 @@
 args.dataset)
 dataset = Yandex(root=path, name=args.dataset)
 dataset.materialize()
-train_dataset = dataset.get_split_dataset('train')
-val_dataset = dataset.get_split_dataset('val')
-test_dataset = dataset.get_split_dataset('test')
+train_dataset, val_dataset, test_dataset = dataset.split()
 train_tensor_frame = train_dataset.tensor_frame
 val_tensor_frame = val_dataset.tensor_frame
 test_tensor_frame = test_dataset.tensor_frame
9 changes: 3 additions & 6 deletions examples/ft_transformer_text.py
@@ -71,12 +71,9 @@ def __call__(self, sentences: List[str]) -> Tensor:

 is_classification = dataset.task_type.is_classification

-train_dataset = dataset.get_split_dataset('train')
-val_dataset = dataset.get_split_dataset('val')
-test_dataset = dataset.get_split_dataset('test')
-if val_dataset.tensor_frame.num_rows == 0:
-    train_dataset = dataset.get_split_dataset('train')[:0.9]
-    val_dataset = dataset.get_split_dataset('train')[0.9:]
+train_dataset, val_dataset, test_dataset = dataset.split()
+if len(val_dataset) == 0:
+    train_dataset, val_dataset = train_dataset[:0.9], train_dataset[0.9:]

 # Set up data loaders
 train_tensor_frame = train_dataset.tensor_frame
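Note on the fallback above: when the validation split is empty, the training split is cut with fractional slice bounds (`[:0.9]` and `[0.9:]`). The following is a minimal sketch of how fractional bounds could be turned into row positions. The helper names `_to_position` and `fractional_slice` are hypothetical, and the rounding rule is an assumption; this is not the library's implementation.

```python
from typing import List, Optional, Sequence, TypeVar, Union

T = TypeVar("T")
Bound = Optional[Union[int, float]]


def _to_position(bound: Bound, num_rows: int) -> Optional[int]:
    """Turn a fractional bound (e.g., 0.9) into an integer row position.

    Assumption: fractions are scaled by the number of rows and rounded.
    """
    if isinstance(bound, float):
        return round(bound * num_rows)
    return bound


def fractional_slice(rows: Sequence[T], start: Bound = None,
                     stop: Bound = None) -> List[T]:
    """Hypothetical helper: slice `rows` with integer or fractional bounds."""
    num_rows = len(rows)
    return list(rows[_to_position(start, num_rows):_to_position(stop, num_rows)])


rows = list(range(10))
train_rows = fractional_slice(rows, stop=0.9)   # first 90% of the rows
val_rows = fractional_slice(rows, start=0.9)    # remaining 10% of the rows
```

Because both slices use the same boundary, the two parts are disjoint and together cover every row.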
4 changes: 1 addition & 3 deletions examples/revisiting.py
@@ -58,9 +58,7 @@
 dataset.materialize()
 is_classification = dataset.task_type.is_classification

-train_dataset = dataset.get_split_dataset('train')
-val_dataset = dataset.get_split_dataset('val')
-test_dataset = dataset.get_split_dataset('test')
+train_dataset, val_dataset, test_dataset = dataset.split()

 # Set up data loaders
 train_tensor_frame = train_dataset.tensor_frame
4 changes: 1 addition & 3 deletions examples/tutorial.py
@@ -52,9 +52,7 @@
 assert dataset.task_type.is_classification

 # Get pre-defined split
-train_dataset = dataset.get_split_dataset('train')
-val_dataset = dataset.get_split_dataset('val')
-test_dataset = dataset.get_split_dataset('test')
+train_dataset, val_dataset, test_dataset = dataset.split()

 train_tensor_frame = train_dataset.tensor_frame
 val_tensor_frame = val_dataset.tensor_frame
2 changes: 1 addition & 1 deletion test/transforms/test_mutual_information_sort.py
@@ -19,7 +19,7 @@ def test_mutual_information_sort(with_nan):
     dataset.materialize()

     tensor_frame: TensorFrame = dataset.tensor_frame
-    train_dataset = dataset.get_split_dataset('train')
+    train_dataset = dataset.get_split('train')
     transform = MutualInformationSort(task_type)
     transform.fit(train_dataset.tensor_frame, train_dataset.col_stats)
     out = transform(tensor_frame)
35 changes: 22 additions & 13 deletions torch_frame/data/dataset.py
@@ -53,7 +53,7 @@ def _requires_post_materialization(self, *args, **kwargs):


 class DataFrameToTensorFrameConverter:
-    r"""DataFrame to TensorFrame converter.
+    r"""A data frame to :class:`TensorFrame` converter.
     Args:
         col_to_stype (Dict[str, :class:`torch_frame.stype`]):
@@ -148,7 +148,7 @@ def __call__(


 class Dataset(ABC):
-    r"""Base class for creating tabular datasets.
+    r"""A base class for creating tabular datasets.
     Args:
         df (DataFrame): The tabular data frame.
@@ -160,9 +160,8 @@ class Dataset(ABC):
             information. The column should only contain :obj:`0`, :obj:`1`, or
             :obj:`2`. (default: :obj:`None`).
         text_embedder_cfg (TextEmbedderConfig, optional): A text embedder
-            config specifying :obj:`text_embedder` that maps sentences into
-            PyTorch embeddings and :obj:`batch_size` that specifies the
-            mini-batch size for :obj:`text_embedder` (default: :obj:`None`)
+            configuration that specifies the text embedder to map text columns
+            into :pytorch:`PyTorch` embeddings. (default: :obj:`None`)
     """
     def __init__(
         self,
@@ -263,7 +262,7 @@ def task_type(self) -> TaskType:

     @property
     def num_rows(self):
-        r"""Number of rows."""
+        r"""The number of rows of the dataset."""
         return len(self.df)

     @property
@@ -415,23 +414,33 @@ def col_select(self, cols: ColumnSelectType) -> 'Dataset':

         return dataset

-    def get_split_dataset(self, split: str) -> 'Dataset':
-        r"""Get splitted dataset defined in `split_col` of :obj:`self.df`.
+    def get_split(self, split: str) -> 'Dataset':
+        r"""Returns a subset of the dataset that belongs to a given training
+        split (as defined in :obj:`split_col`).
         Args:
-            split (str): The split name. Should be 'train', 'val', or 'test'.
+            split (str): The split name (either :obj:`"train"`, :obj:`"val"`,
+                or :obj:`"test"`).
         """
         if self.split_col is None:
             raise ValueError(
-                f"'get_split_dataset' is not supported for {self} "
-                "since 'split_col' is not specified.")
+                f"'get_split' is not supported for '{self}' since 'split_col' "
+                f"is not specified.")
         if split not in ['train', 'val', 'test']:
-            raise ValueError(f"The split named {split} is not available. "
-                             f"Needs to either 'train', 'val', or 'test'.")
+            raise ValueError(f"The split named '{split}' is not available. "
+                             f"Needs to be either 'train', 'val', or 'test'.")
         indices = self.df.index[self.df[self.split_col] ==
                                 SPLIT_TO_NUM[split]].tolist()
         return self[indices]

+    def split(self) -> Tuple['Dataset', 'Dataset', 'Dataset']:
+        r"""Splits the dataset into training, validation and test splits."""
+        return (
+            self.get_split('train'),
+            self.get_split('val'),
+            self.get_split('test'),
+        )
+
     @property
     @requires_post_materialization
     def convert_to_tensor_frame(self) -> DataFrameToTensorFrameConverter:
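Note on the new `split()` method: it is a thin wrapper that calls `get_split` once per split name. The plain-Python sketch below mirrors that logic on a list of row dictionaries instead of a data frame. The concrete mapping in `SPLIT_TO_NUM` is an assumption (the diff only shows that the split column holds `0`, `1`, or `2`), and the `Row` alias and sample rows are made up for illustration.

```python
from typing import Dict, List, Tuple

# Assumed encoding of split names. The diff references SPLIT_TO_NUM but does
# not show its values; the docstring only says the column contains 0, 1, or 2.
SPLIT_TO_NUM: Dict[str, int] = {"train": 0, "val": 1, "test": 2}

Row = Dict[str, object]


def get_split(rows: List[Row], split_col: str, split: str) -> List[Row]:
    """Return the rows whose split column matches the requested split."""
    if split not in SPLIT_TO_NUM:
        raise ValueError(f"The split named '{split}' is not available. "
                         f"Needs to be either 'train', 'val', or 'test'.")
    return [row for row in rows if row[split_col] == SPLIT_TO_NUM[split]]


def split(rows: List[Row],
          split_col: str) -> Tuple[List[Row], List[Row], List[Row]]:
    """Return the (train, val, test) subsets in that fixed order."""
    return (
        get_split(rows, split_col, "train"),
        get_split(rows, split_col, "val"),
        get_split(rows, split_col, "test"),
    )


rows = [
    {"age": 31, "split": 0},
    {"age": 45, "split": 2},
    {"age": 27, "split": 0},
    {"age": 52, "split": 1},
]
train_rows, val_rows, test_rows = split(rows, "split")
```

Returning a fixed-order tuple is what lets callers replace three `get_split_dataset(...)` calls with a single unpacking assignment, as the example scripts in this commit do.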
14 changes: 13 additions & 1 deletion torch_frame/data/loader.py
@@ -8,7 +8,19 @@


 class DataLoader(torch.utils.data.DataLoader):
     r"""A data loader which creates mini-batches from a
-    :class:`torch_frame.Dataset` or :class:`torch_frame.TensorFrame`.
+    :class:`torch_frame.Dataset` or :class:`torch_frame.TensorFrame` object.
+
+    .. code-block:: python
+
+        import torch_frame
+
+        dataset = ...
+        loader = torch_frame.data.DataLoader(
+            dataset,
+            batch_size=512,
+            shuffle=True,
+        )
+
     Args:
         dataset (Dataset or TensorFrame): The dataset or tensor frame from
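Note on mini-batching: the docstring example passes `batch_size` and `shuffle` to the loader. As a rough illustration of what those two arguments mean for row selection, here is a small sketch that yields row indices per mini-batch. The function `iter_batches` and its `seed` parameter are hypothetical and unrelated to the actual `DataLoader` internals.

```python
import random
from typing import Iterator, List


def iter_batches(num_rows: int, batch_size: int, shuffle: bool = False,
                 seed: int = 0) -> Iterator[List[int]]:
    """Yield lists of row indices, one list per mini-batch."""
    indices = list(range(num_rows))
    if shuffle:
        # Shuffle once per pass so each row appears in exactly one batch.
        random.Random(seed).shuffle(indices)
    for start in range(0, num_rows, batch_size):
        yield indices[start:start + batch_size]


batches = list(iter_batches(num_rows=10, batch_size=4))
shuffled = list(iter_batches(num_rows=10, batch_size=4, shuffle=True))
```

The last batch may be smaller than `batch_size` when the number of rows is not a multiple of it.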
9 changes: 4 additions & 5 deletions torch_frame/data/stats.py
@@ -11,12 +11,11 @@ class StatType(Enum):
     r"""The different types for column statistics.
     Attributes:
-        MEAN: Mean. Numerical column only.
-        STD: Standard deviation. Numerical column only.
+        MEAN: The average value of a numerical column.
+        STD: The standard deviation of a numerical column.
         QUANTILES: The minimum, first quartile, median, third quartile,
-            and the maximum of the column. Numerical column only.
-        COUNT: The count of each class. Categorical column only.
+            and the maximum of a numerical column.
+        COUNT: The count of each category in a categorical column.
     """
     # Numerical:
     MEAN = 'MEAN'
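Note on `StatType`: the four statistics described in the docstring can be illustrated with the Python standard library. This is only a sketch under stated assumptions: the functions `numerical_stats` and `categorical_stats` are hypothetical, the population standard deviation and the "inclusive" quantile method are arbitrary choices here, and the library may compute these values differently.

```python
import statistics
from collections import Counter
from typing import Dict, List


def numerical_stats(values: List[float]) -> Dict[str, object]:
    """Compute MEAN, STD and QUANTILES for a numerical column."""
    q1, median, q3 = statistics.quantiles(values, n=4, method="inclusive")
    return {
        "MEAN": statistics.fmean(values),
        # Assumption: population standard deviation.
        "STD": statistics.pstdev(values),
        # Minimum, first quartile, median, third quartile, maximum.
        "QUANTILES": [min(values), q1, median, q3, max(values)],
    }


def categorical_stats(values: List[str]) -> Dict[str, object]:
    """Compute COUNT (occurrences of each category) for a categorical column."""
    return {"COUNT": dict(Counter(values))}


num = numerical_stats([1.0, 2.0, 3.0, 4.0, 5.0])
cat = categorical_stats(["a", "b", "a", "c", "a"])
```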
64 changes: 50 additions & 14 deletions torch_frame/data/tensor_frame.py
@@ -13,18 +13,54 @@
 @dataclass(repr=False)
 class TensorFrame:
     r"""A tensor frame holds a :pytorch:`PyTorch` tensor for each table column.
-    Table columns are first organized into their semantic types (e.g.,
-    categorical, numerical) and then converted into their tensor
-    representation, which is stored as :obj:`feat_dict`. For instance,
-    :obj:`feat_dict[stype.numerical]` stores a concatenated :pytorch:`PyTorch`
-    tensor for all numerical features, where 0th/1st dim represents the
-    row/column in the original DataFrame, respectively.
-    :obj:`col_names_dict` stores column names of :obj:`feat_dict`. For example,
-    :obj:`col_names_dict[stype.numerical][i]` stores the column name of
-    :obj:`feat_dict[stype.numerical][:,i]`.
-    Additionally, TensorFrame can store the target values in :obj:`y`.
+    Table columns are organized into their semantic types
+    :class:`~torch_frame.stype` (*e.g.*, categorical, numerical) and mapped to
+    a compact tensor representation (*e.g.*, strings in a categorical column
+    are mapped to indices from :obj:`{0, ..., num_categories - 1}`), and can be
+    accessed through :obj:`feat_dict`.
+    For instance, :obj:`feat_dict[stype.numerical]` stores a concatenated
+    :pytorch:`PyTorch` tensor for all numerical features, where the first and
+    second dimensions represent the rows and columns of the original data
+    frame, respectively.
+    :class:`TensorFrame` handles missing values via :obj:`float('NaN')` for
+    floating-point tensors, and :obj:`-1` otherwise.
+    :obj:`col_names_dict` maps each column in :obj:`feat_dict` to its
+    original column name.
+    For example, :obj:`col_names_dict[stype.numerical][i]` stores the column
+    name of :obj:`feat_dict[stype.numerical][:, i]`.
+    Additionally, :class:`TensorFrame` can store any target values in :obj:`y`.
+
+    .. code-block:: python
+
+        import torch
+        import torch_frame
+
+        tf = torch_frame.TensorFrame(
+            feat_dict={
+                # Two numerical columns:
+                torch_frame.numerical: torch.randn(10, 2),
+                # Three categorical columns:
+                torch_frame.categorical: torch.randint(0, 5, (10, 3)),
+            },
+            col_names_dict={
+                torch_frame.numerical: ['x', 'y'],
+                torch_frame.categorical: ['a', 'b', 'c'],
+            },
+        )
+        print(len(tf))
+        >>> 10
+        # Row-wise filtering:
+        tf = tf[torch.tensor([0, 2, 4, 6, 8])]
+        print(len(tf))
+        >>> 5
+        # Transfer tensor frame to the GPU:
+        tf = tf.to('cuda')
     """
     feat_dict: Dict[torch_frame.stype, Tensor]
     col_names_dict: Dict[torch_frame.stype, List[str]]
@@ -34,7 +70,7 @@ def __post_init__(self):
         self.validate()

     def validate(self):
-        r"""Validate the tensor frame object."""
+        r"""Validates the :class:`TensorFrame` object."""
         if self.feat_dict.keys() != self.col_names_dict.keys():
             raise RuntimeError(
                 f"The keys of feat_dict and col_names_dict must be the same, "
@@ -76,7 +112,7 @@ def validate(self):

     @property
     def stypes(self) -> List[stype]:
-        r"""Returns a canonical ordering of stypes in :obj:`feat_dict`"""
+        r"""Returns a canonical ordering of stypes in :obj:`feat_dict`."""
         return list(
             filter(lambda x: x in self.feat_dict, list(torch_frame.stype)))
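Note on the `TensorFrame` docstring: it states that categorical strings are mapped to indices from `{0, ..., num_categories - 1}`, and that missing values become `NaN` in floating-point tensors and `-1` otherwise. The sketch below illustrates that encoding with plain Python lists rather than tensors. The helper names are hypothetical, and assigning indices in sorted order is an assumption made only to keep the example deterministic; the library may order categories differently.

```python
import math
from typing import List, Optional


def encode_categorical(column: List[Optional[str]]) -> List[int]:
    """Map categories to indices in {0, ..., num_categories - 1}; missing -> -1."""
    # Assumption: indices follow the sorted order of the observed categories.
    categories = sorted({value for value in column if value is not None})
    index = {category: i for i, category in enumerate(categories)}
    return [-1 if value is None else index[value] for value in column]


def encode_numerical(column: List[Optional[float]]) -> List[float]:
    """Keep numerical values as floats; missing -> NaN."""
    return [math.nan if value is None else float(value) for value in column]


table = {
    "color": ["red", None, "blue", "red"],
    "price": [1.5, None, 3.0, 2.0],
}
color_idx = encode_categorical(table["color"])
price = encode_numerical(table["price"])
```

Keeping the missing-value markers in the encoded data, rather than imputing them, matches the documentation's statement that missing-value handling is left to the model side.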
25 changes: 12 additions & 13 deletions torch_frame/gbdt/gbdt.py
@@ -42,9 +42,9 @@ def tune(self, tf_train: TensorFrame, tf_val: TensorFrame, num_trials: int,
         number of trials is specified by num_trials.
         Args:
-            tf_train (TensorFrame): The train data in :obj:`TensorFrame`.
-            tf_val (TensorFrame): The validation data in :obj:`TensorFrame`.
-            num_trials (int): Number of trials to perform hyperparameter
+            tf_train (TensorFrame): The train data in :class:`TensorFrame`.
+            tf_val (TensorFrame): The validation data in :class:`TensorFrame`.
+            num_trials (int): Number of trials to perform hyper-parameter
                 search.
         """
         if tf_train.y is None:
@@ -55,17 +55,16 @@ def tune(self, tf_train: TensorFrame, tf_val: TensorFrame, num_trials: int,
         self._is_fitted = True

     def predict(self, tf_test: TensorFrame) -> Tensor:
-        r"""Predict the labels/values of the test data on the fitted model.
+        r"""Predicts the labels/values of the test data on the fitted model and
+        returns its predictions:
-        Returns:
-            prediction (Tensor): The prediction output :obj:`Tensor` on the
-                fitted model. Prediction depends on the task type.
-                - If regression, pred contains numerical value prediction.
-                - If binary classification, pred contains the probability of
-                  being positive.
-                - If multi-class classification, pred contains the class label
-                  predictions.
+        - :obj:`TaskType.REGRESSION`: Returns raw numerical values.
+        - :obj:`TaskType.BINARY_CLASSIFICATION`: Returns the probability of
+          being positive.
+        - :obj:`TaskType.MULTICLASS_CLASSIFICATION`: Returns the class label
+          predictions.
         """
         if not self.is_fitted:
             raise RuntimeError(
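Note on `predict`: the revised docstring lists a different output form per task type. The sketch below shows one way such a dispatch could look when starting from raw model scores. Everything here is an assumption for illustration: the `TaskType` enum is a stand-in that only reuses the member names from the docstring, `format_predictions` is a hypothetical helper, and real GBDT backends may already return probabilities or labels directly.

```python
import math
from enum import Enum
from typing import List, Sequence


class TaskType(Enum):
    # Stand-in enum: member names come from the docstring, values are
    # placeholders chosen for this sketch.
    REGRESSION = "regression"
    BINARY_CLASSIFICATION = "binary_classification"
    MULTICLASS_CLASSIFICATION = "multiclass_classification"


def format_predictions(raw: Sequence, task_type: TaskType) -> List:
    """Convert raw scores into the output form described for each task type."""
    if task_type is TaskType.REGRESSION:
        # Raw numerical values are returned unchanged (as floats).
        return [float(value) for value in raw]
    if task_type is TaskType.BINARY_CLASSIFICATION:
        # Assumption: raw values are logits; the sigmoid gives P(positive).
        return [1.0 / (1.0 + math.exp(-logit)) for logit in raw]
    # Multi-class: pick the index of the highest score per row.
    return [max(range(len(scores)), key=lambda i: scores[i]) for scores in raw]


regression_out = format_predictions([1, 2.5], TaskType.REGRESSION)
binary_out = format_predictions([0.0, 2.0], TaskType.BINARY_CLASSIFICATION)
multiclass_out = format_predictions(
    [[0.1, 2.0, -1.0], [3.0, 0.0, 1.0]], TaskType.MULTICLASS_CLASSIFICATION)
```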