Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 97 additions & 16 deletions docs/examples/plot_linear_tree_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,38 @@
Handling Data with Many Labels Using Linear Methods
====================================================
For the case that the amount of labels is very large,
the training time of the standard ``train_1vsrest`` method may be unpleasantly long.
The ``train_tree`` method in LibMultiLabel can vastly improve the training time on such data sets.
For datasets with a very large number of labels, the training time of the standard ``train_1vsrest`` method can be prohibitively long. LibMultiLabel offers tree-based methods like ``train_tree`` and ``train_ensemble_tree`` to vastly improve training time in such scenarios.
To illustrate this speedup, we will use the `EUR-Lex dataset <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html#EUR-Lex>`_, which contains 3,956 labels.
The data in the following example is downloaded under the directory ``data/eur-lex``
Users can use the following command to easily apply the ``train_tree`` method.
.. code-block:: bash
$ python3 main.py --training_file data/eur-lex/train.txt
--test_file data/eur-lex/test.txt
--linear
--linear_technique tree
Besides CLI usage, users can also use API to apply ``train_tree`` method.
Below is an example.
We will use the `EUR-Lex dataset <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html#EUR-Lex>`_, which contains 3,956 labels. The data is assumed to be downloaded under the directory ``data/eur-lex``.
"""

import math
import libmultilabel.linear as linear
import time

# Load the training and test splits (both given in the "txt" format) and
# preprocess them.
# NOTE(review): the surrounding text and the CLI examples refer to the directory
# ``data/eur-lex``, while the paths below use ``data/eurlex`` -- confirm which
# directory name is intended.
datasets = linear.load_dataset("txt", "data/eurlex/train.txt", "data/eurlex/test.txt")
# presumably ``fit_transform`` produces the ``x``/``y`` matrices that are read
# from ``datasets["train"]`` and ``datasets["test"]`` further down -- TODO confirm
preprocessor = linear.Preprocessor()
datasets = preprocessor.fit_transform(datasets)


######################################################################
# Standard Training and Prediction
# --------------------------------
#
# Users can use the following command to easily apply the ``train_tree`` method.
#
# .. code-block:: bash
#
#     $ python3 main.py --training_file data/eur-lex/train.txt \
#                       --test_file data/eur-lex/test.txt \
#                       --linear \
#                       --linear_technique tree
#
# Besides CLI usage, users can also use the API to apply the ``train_tree`` method.
# Below is an example.

training_start = time.time()
# the standard one-vs-rest method for multi-label problems
ovr_model = linear.train_1vsrest(datasets["train"]["y"], datasets["train"]["x"])
Expand Down Expand Up @@ -99,3 +102,81 @@ def metrics_in_batches(model):
print("Score of 1vsrest:", metrics_in_batches(ovr_model))
print("Score of tree:", metrics_in_batches(tree_model))


######################################################################
# Ensemble of Tree Models
# -----------------------
#
# While the ``train_tree`` method offers a significant speedup, its accuracy can sometimes be slightly lower than the standard one-vs-rest approach.
# The ``train_ensemble_tree`` method can help bridge this gap by training multiple tree models and averaging their predictions.
#
# Users can use the following command to easily apply the ``train_ensemble_tree`` method.
# The number of trees in the ensemble can be controlled with the ``--tree_ensemble_models`` argument.
#
# .. code-block:: bash
#
#     $ python3 main.py --training_file data/eur-lex/train.txt \
#                       --test_file data/eur-lex/test.txt \
#                       --linear \
#                       --linear_technique tree \
#                       --tree_ensemble_models 3
#
# This command trains an ensemble of 3 tree models. If ``--tree_ensemble_models`` is not specified, it defaults to 1 (a single tree).
#
# Besides CLI usage, users can also use the API to apply the ``train_ensemble_tree`` method.
# Below is an example.

# The single tree trained earlier serves as the baseline.
# Here we time the training of an ensemble made of 3 trees on the same data.
training_start = time.time()
ensemble_model = linear.train_ensemble_tree(datasets["train"]["y"], datasets["train"]["x"], n_trees=3)
training_end = time.time()
print(f"Training time of ensemble tree: {training_end - training_start:10.2f}")

######################################################################
# On a machine with an AMD-7950X CPU,
# the ``train_ensemble_tree`` function with 3 trees took `421.15` seconds,
# while the single tree took `144.37` seconds.
# As expected, training an ensemble takes longer, roughly proportional to the number of trees.
#
# Now, let's see if this additional training time translates to better performance.
# We'll compute the same P@K metrics on the test set for both the single tree and the ensemble model.

# Only the ensemble still needs decision values on the test set;
# `tree_preds` and `target` come from the previous section.
ensemble_preds = linear.predict_values(
    ensemble_model,
    datasets["test"]["x"],
)

# Print the baseline first (`tree_score` was computed in the previous section).
print("Score of single tree:", tree_score)

# Evaluate the ensemble with the same P@K metrics as the baseline.
ensemble_score = linear.compute_metrics(
    ensemble_preds,
    target,
    ["P@1", "P@3", "P@5"],
)
print("Score of ensemble tree:", ensemble_score)

######################################################################
# While training an ensemble takes longer, it often leads to better predictive performance.
# The following table compares a single tree with ensembles
# of 3, 10, and 15 trees on two benchmark datasets (EURLex-4k and EURLex-57k).
#
# .. table:: Benchmark Results for Single and Ensemble Tree Models (P@K in %)
#
#     +---------------+-----------------+-------+-------+-------+
#     | Dataset       | Model           | P@1   | P@3   | P@5   |
#     +===============+=================+=======+=======+=======+
#     | EURLex-4k     | Single Tree     | 82.35 | 68.98 | 57.62 |
#     |               +-----------------+-------+-------+-------+
#     |               | Ensemble-3      | 82.38 | 69.28 | 58.01 |
#     |               +-----------------+-------+-------+-------+
#     |               | Ensemble-10     | 82.74 | 69.66 | 58.39 |
#     |               +-----------------+-------+-------+-------+
#     |               | Ensemble-15     | 82.61 | 69.56 | 58.29 |
#     +---------------+-----------------+-------+-------+-------+
#     | EURLex-57k    | Single Tree     | 90.77 | 80.81 | 67.82 |
#     |               +-----------------+-------+-------+-------+
#     |               | Ensemble-3      | 91.02 | 81.06 | 68.26 |
#     |               +-----------------+-------+-------+-------+
#     |               | Ensemble-10     | 91.23 | 81.22 | 68.34 |
#     |               +-----------------+-------+-------+-------+
#     |               | Ensemble-15     | 91.25 | 81.31 | 68.34 |
#     +---------------+-----------------+-------+-------+-------+

76 changes: 73 additions & 3 deletions libmultilabel/linear/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from . import linear

__all__ = ["train_tree", "TreeModel"]
__all__ = ["train_tree", "TreeModel", "train_ensemble_tree", "EnsembleTreeModel"]

# Default maximum degree of tree nodes (the ``K`` argument of the tree trainers).
DEFAULT_K = 100
# Default maximum depth of the tree (the ``dmax`` argument of the tree trainers).
DEFAULT_DMAX = 10


class Node:
Expand Down Expand Up @@ -198,8 +201,8 @@ def train_tree(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
options: str = "",
K=100,
dmax=10,
K=DEFAULT_K,
dmax=DEFAULT_DMAX,
verbose: bool = True,
) -> TreeModel:
"""Train a linear model for multi-label data using a divide-and-conquer strategy.
Expand Down Expand Up @@ -382,3 +385,70 @@ def visit(node):
node_ptr = np.cumsum([0] + list(map(lambda w: w.shape[1], weights)))

return model, node_ptr


class EnsembleTreeModel:
    """An ensemble of tree models.

    The ensemble scores an instance by averaging the values that each member
    tree predicts for it.
    """

    def __init__(self, tree_models: list["TreeModel"]):
        """
        Args:
            tree_models (list[TreeModel]): A non-empty list of trained tree models.

        Raises:
            ValueError: If ``tree_models`` is empty.
        """
        # An empty ensemble has nothing to average; fail at construction time
        # rather than returning NaN from predict_values later.
        if not tree_models:
            raise ValueError("tree_models must contain at least one trained tree model.")
        self.name = "ensemble-tree"
        self.tree_models = tree_models
        self.multiclass = False

    def predict_values(self, x: sparse.csr_matrix, beam_width: int = 10) -> np.ndarray:
        """Calculates the averaged scores from all trees in the ensemble.

        Args:
            x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
            beam_width (int, optional): Number of candidates considered during beam search for each tree. Defaults to 10.

        Returns:
            np.ndarray: A matrix with dimension number of instances * number of classes, containing averaged scores.
        """
        # Accumulate a running sum instead of collecting every tree's dense
        # prediction matrix: peak memory stays at one accumulator plus one
        # tree's output, independent of the number of trees.
        total = None
        for model in self.tree_models:
            preds = model.predict_values(x, beam_width)
            if total is None:
                # Copy so the in-place additions below never modify an array
                # owned by a member model.
                total = np.array(preds, copy=True)
            else:
                total += preds
        return total / len(self.tree_models)


def train_ensemble_tree(
    y: sparse.csr_matrix,
    x: sparse.csr_matrix,
    options: str = "",
    K: int = DEFAULT_K,
    dmax: int = DEFAULT_DMAX,
    n_trees: int = 3,
    verbose: bool = True,
    seed: int = None,
) -> EnsembleTreeModel:
    """Trains an ensemble of tree models (Parabel/Bonsai-style).

    Every tree is trained with ``train_tree`` on the same data. Before the
    i-th tree is trained, NumPy's global random generator is reseeded with
    ``seed + i``; this global state is not restored afterwards.

    Args:
        y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
        x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
        options (str, optional): The option string passed to liblinear. Defaults to ''.
        K (int, optional): Maximum degree of nodes in the tree. Defaults to 100.
        dmax (int, optional): Maximum depth of the tree. Defaults to 10.
        n_trees (int, optional): Number of trees in the ensemble. Must be at least 1. Defaults to 3.
        verbose (bool, optional): Output extra progress information. Defaults to True.
        seed (int, optional): The base random seed for the ensemble. Defaults to None, which will use 42.

    Returns:
        EnsembleTreeModel: An ensemble model which can be used for prediction.

    Raises:
        ValueError: If ``n_trees`` is smaller than 1.
    """
    # An ensemble without members cannot produce predictions, so reject it
    # up front instead of returning an unusable model.
    if n_trees < 1:
        raise ValueError(f"n_trees must be at least 1, got {n_trees}.")

    base_seed = 42 if seed is None else seed

    tree_models = []
    for i in range(n_trees):
        # A distinct seed per tree; presumably train_tree draws from NumPy's
        # global RNG, which is what makes the trees differ -- TODO confirm.
        np.random.seed(base_seed + i)
        tree_models.append(train_tree(y, x, options=options, K=K, dmax=dmax, verbose=verbose))

    # Honor ``verbose`` for the completion message as well.
    if verbose:
        print("Ensemble training completed.")

    return EnsembleTreeModel(tree_models)
28 changes: 20 additions & 8 deletions linear_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import libmultilabel.linear as linear
from libmultilabel.common_utils import dump_log, is_multiclass_dataset
from libmultilabel.linear.tree import EnsembleTreeModel, TreeModel, train_ensemble_tree
from libmultilabel.linear.utils import LINEAR_TECHNIQUES


Expand All @@ -21,7 +22,7 @@ def linear_test(config, model, datasets, label_mapping):
scores = []

predict_kwargs = {}
if model.name == "tree":
if isinstance(model, (TreeModel, EnsembleTreeModel)):
predict_kwargs["beam_width"] = config.beam_width

for i in tqdm(range(ceil(num_instance / config.eval_batch_size))):
Expand All @@ -48,13 +49,24 @@ def linear_train(datasets, config):
if multiclass:
raise ValueError("Tree model should only be used with multilabel datasets.")

model = LINEAR_TECHNIQUES[config.linear_technique](
datasets["train"]["y"],
datasets["train"]["x"],
options=config.liblinear_options,
K=config.tree_degree,
dmax=config.tree_max_depth,
)
if config.tree_ensemble_models > 1:
model = train_ensemble_tree(
datasets["train"]["y"],
datasets["train"]["x"],
options=config.liblinear_options,
K=config.tree_degree,
dmax=config.tree_max_depth,
n_trees=config.tree_ensemble_models,
seed=config.seed,
)
else:
model = LINEAR_TECHNIQUES[config.linear_technique](
datasets["train"]["y"],
datasets["train"]["x"],
options=config.liblinear_options,
K=config.tree_degree,
dmax=config.tree_max_depth,
)
else:
model = LINEAR_TECHNIQUES[config.linear_technique](
datasets["train"]["y"],
Expand Down
3 changes: 3 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ def add_all_arguments(parser):
parser.add_argument(
"--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)"
)
parser.add_argument(
"--tree_ensemble_models", type=int, default=1, help="Number of models in the tree ensemble (default: %(default)s)"
)
parser.add_argument(
"--beam_width",
type=int,
Expand Down