146 changes: 125 additions & 21 deletions libmultilabel/linear/linear.py
@@ -2,12 +2,18 @@

import logging
import os
import psutil
import threading
import queue
import re

import numpy as np
import scipy.sparse as sparse
from liblinear.liblinearutil import train, problem, parameter, solver_names
from tqdm import tqdm

from ctypes import c_double

__all__ = [
"train_1vsrest",
"train_thresholding",
@@ -86,14 +92,114 @@ def _to_dense_array(self, matrix: np.matrix | sparse.csr_matrix) -> np.ndarray:
return np.asarray(matrix)


class ParallelOVRTrainer(threading.Thread):
    """A worker thread for parallel one-vs-rest training. Shared training state is
    stored in class variables set by init_trainer; each thread pulls label indices
    from a shared queue and trains one binary subproblem per index."""

y: sparse.csc_matrix
x: sparse.csr_matrix
bias: float
prob: problem
param: parameter
weights: np.ndarray
pbar: tqdm
queue: queue.SimpleQueue

def __init__(self):
threading.Thread.__init__(self)

@classmethod
def init_trainer(
cls,
y: sparse.csr_matrix,
x: sparse.csr_matrix,
options: str,
verbose: bool,
):
"""Initialize the parallel trainer by setting y, x, parameter and threading related
variables as class variables of ParallelOVRTrainer.

Args:
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
options (str): The option string passed to liblinear.
verbose (bool): Output extra progress information.
"""
x, options, bias = _prepare_options(x, options)
cls.y = y.tocsc()
cls.x = x
cls.bias = bias
num_instances, num_classes = cls.y.shape
num_features = cls.x.shape[1]
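        # Build the liblinear problem once with placeholder targets of all ones;
        # each worker copies it and swaps in the real +1/-1 labels.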
cls.prob = problem(np.ones((num_instances,)), cls.x)

# remove "-m nr_thread" from options to prevent nested multi-threading
cls.param = parameter(re.sub(r"-m\s+\d+", "", options))
if cls.param.solver_type in [solver_names.L2R_L1LOSS_SVC_DUAL, solver_names.L2R_L2LOSS_SVC_DUAL]:
        cls.param.w_recalc = True  # recalculating w is only supported when solving the L1-/L2-loss SVM dual
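        # Fortran (column-major) order keeps each label's weight column contiguous;
        # workers write to disjoint columns, so no locking is needed.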
cls.weights = np.zeros((num_features, num_classes), order="F")
cls.queue = queue.SimpleQueue()

if verbose:
logging.info(f"Training a one-vs-rest model on {num_classes} labels")
for i in range(num_classes):
cls.queue.put(i)
cls.pbar = tqdm(total=num_classes, disable=not verbose)

@classmethod
def del_trainer(cls):
cls.pbar.close()
for key in list(cls.__annotations__):
delattr(cls, key)

def _do_parallel_train(self, y: np.ndarray) -> np.matrix:
"""Wrap around liblinear.liblinearutil.train.

Args:
y (np.ndarray): A +1/-1 array with dimensions number of instances * 1.

Returns:
np.matrix: The weights.
"""
if y.shape[0] == 0:
return np.matrix(np.zeros((self.prob.n, 1)))

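        # Copy the shared problem and substitute this label's targets; liblinear
        # expects y as a C array of doubles.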
prob = self.prob.copy()
prob.y = (c_double * prob.l)(*y)
model = train(prob, self.param)

w = np.ctypeslib.as_array(model.w, (self.prob.n, 1))
w = np.asmatrix(w)
# When all labels are -1, we must flip the sign of the weights
# because LIBLINEAR treats the first label as positive, which
# is -1 in this case. But for our usage we need them to be negative.
# For data with both +1 and -1 for labels, LIBLINEAR guarantees
# that +1 is always the first label.
if model.get_labels()[0] == -1:
return -w
else:
# The memory is freed on model deletion so we make a copy.
return w.copy()

def run(self):
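        # Each worker loops, claiming the next unprocessed label; queue.Empty
        # signals that all labels have been dispatched.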
while True:
try:
label_idx = self.queue.get_nowait()
except queue.Empty:
break
yi = self.y[:, label_idx].toarray().reshape(-1)
self.weights[:, label_idx] = self._do_parallel_train(2 * yi - 1).ravel()

self.pbar.update()


def train_1vsrest(
y: sparse.csr_matrix,
x: sparse.csr_matrix,
multiclass: bool = False,
options: str = "",
verbose: bool = True,
) -> FlatModel:
"""Train a linear model for multi-label data using a one-vs-rest strategy.
"""Train a linear model parallel on labels for multi-label data using a one-vs-rest strategy.

Args:
y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
@@ -106,18 +212,15 @@ def train_1vsrest(
A model which can be used in predict_values.
"""
# Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
x, options, bias = _prepare_options(x, options)

y = y.tocsc()
num_class = y.shape[1]
num_feature = x.shape[1]
weights = np.zeros((num_feature, num_class), order="F")

if verbose:
logging.info(f"Training one-vs-rest model on {num_class} labels")
for i in tqdm(range(num_class), disable=not verbose):
yi = y[:, i].toarray().reshape(-1)
weights[:, i] = _do_train(2 * yi - 1, x, options).ravel()
ParallelOVRTrainer.init_trainer(y, x, options, verbose)
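    # One worker thread per physical core; liblinear's C training routine is
    # called through ctypes, which releases the GIL, so workers run concurrently.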
num_threads = psutil.cpu_count(logical=False)
trainers = [ParallelOVRTrainer() for _ in range(num_threads)]
for trainer in trainers:
trainer.start()
for trainer in trainers:
trainer.join()
weights, bias = ParallelOVRTrainer.weights, ParallelOVRTrainer.bias
ParallelOVRTrainer.del_trainer()

return FlatModel(
name="1vsrest",
@@ -170,7 +273,7 @@ def _prepare_options(x: sparse.csr_matrix, options: str) -> tuple[sparse.csr_mat
if not "-q" in options_split:
options_split.append("-q")
if not "-m" in options:
options_split.append(f"-m {int(os.cpu_count() / 2)}")
options_split.append(f"-m {psutil.cpu_count(logical=False)}")

options = " ".join(options_split)
return x, options, bias
@@ -212,7 +315,7 @@ def train_thresholding(
thresholds = np.zeros(num_class)

if verbose:
logging.info("Training thresholding model on %s labels", num_class)
logging.info("Training a thresholding model on %s labels", num_class)

num_positives = np.sum(y, 2)
label_order = np.flip(np.argsort(num_positives)).flat
@@ -356,10 +459,11 @@ def _do_train(y: np.ndarray, x: sparse.csr_matrix, options: str) -> np.matrix:

w = np.ctypeslib.as_array(model.w, (x.shape[1], 1))
w = np.asmatrix(w)
# Liblinear flips +1/-1 labels so +1 is always the first label,
# but not if all labels are -1.
# For our usage, we need +1 to always be the first label,
# so the check is necessary.
# When all labels are -1, we must flip the sign of the weights
# because LIBLINEAR treats the first label as positive, which
# is -1 in this case. But for our usage we need them to be negative.
# For data with both +1 and -1, LIBLINEAR guarantees that +1
# is always the first label.
if model.get_labels()[0] == -1:
return -w
else:
@@ -440,7 +544,7 @@ def train_cost_sensitive(
weights = np.zeros((num_feature, num_class), order="F")

if verbose:
logging.info(f"Training cost-sensitive model for Macro-F1 on {num_class} labels")
logging.info(f"Training a cost-sensitive model for Macro-F1 on {num_class} labels")
for i in tqdm(range(num_class), disable=not verbose):
yi = y[:, i].toarray().reshape(-1)
w = _cost_sensitive_one_label(2 * yi - 1, x, options)
@@ -549,7 +653,7 @@ def train_cost_sensitive_micro(
bestScore = -np.inf

if verbose:
logging.info(f"Training cost-sensitive model for Micro-F1 on {num_class} labels")
logging.info(f"Training a cost-sensitive model for Micro-F1 on {num_class} labels")
for a in param_space:
tp = fn = fp = 0
for i in tqdm(range(num_class), disable=not verbose):