[python-package] Custom multiclass loss function doesn't work #4981

Closed
@rosmineb

Description

Hello,

I'm getting strange results with custom multiclass loss functions. I implemented multiclass log loss as a custom loss function and trained while evaluating on 3 validation sets: the training data, the training data in shuffled order, and a held-out set. The loss decreases on the train set, but it does not decrease on the same training data in shuffled order. Since the loss shouldn't depend on the order of the samples in the dataset, I would expect the training data and the shuffled training data to have the same loss.
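To make that expectation concrete, here is a minimal NumPy sketch (independent of LightGBM and XGBoost) showing that the mean multiclass log loss is unchanged when the labels and the rows of an (n_samples, n_classes) score matrix are permuted together. The helper name `mean_multiclass_logloss` is hypothetical, and the invariance only holds while each row of scores stays aligned with its own label.

```python
import numpy as np


def mean_multiclass_logloss(y_true, raw_scores):
    """Mean softmax cross-entropy for integer labels and an (n_samples, n_classes) score matrix."""
    # Row-wise log-softmax, shifted by the row max for numerical stability.
    shifted = raw_scores - raw_scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class in each row.
    return -log_probs[np.arange(len(y_true)), y_true.astype(int)].mean()


rng = np.random.default_rng(0)
scores = rng.normal(size=(8, 4))
labels = rng.integers(0, 4, size=8)
perm = rng.permutation(8)

# Permuting labels and score rows together leaves the mean loss unchanged.
print(np.isclose(mean_multiclass_logloss(labels, scores),
                 mean_multiclass_logloss(labels[perm], scores[perm])))  # True
```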

As a sanity check, I used the same loss function with XGBoost, which successfully trained an accurate model.

I saw issue #1644, which suggested that the problem is zero Hessians. I tried the suggestion there (adding a small constant, 1e-6, to the Hessian) and still had the same problem.
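For reference, softmax cross-entropy has closed-form derivatives, so numerical differentiation is not strictly required: with p = softmax(z) and true class y, the gradient with respect to the raw score z_k is p_k - 1[k == y], and the diagonal of the Hessian is p_k (1 - p_k). That diagonal only approaches zero when a predicted probability saturates near 0 or 1, which is the situation #1644 is about. The sketch below assumes scores arranged as an (n_samples, n_classes) matrix, uses hypothetical helper names, and leaves flattening to whatever layout the training library expects; it checks the formulas against central finite differences.

```python
import numpy as np


def softmax_ce_grad_hess(y_true, raw_scores):
    """Analytic gradient and diagonal Hessian of softmax cross-entropy.

    raw_scores has shape (n_samples, n_classes); both outputs have the same shape.
    """
    shifted = raw_scores - raw_scores.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    one_hot = np.eye(raw_scores.shape[1])[y_true.astype(int)]
    grad = probs - one_hot          # dL/dz_k = p_k - 1[k == y]
    hess = probs * (1.0 - probs)    # d2L/dz_k^2 = p_k * (1 - p_k)
    return grad, hess


def per_row_loss(y_true, raw_scores):
    """Per-sample softmax cross-entropy, shape (n_samples,)."""
    shifted = raw_scores - raw_scores.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(y_true)), y_true.astype(int)]


# Compare against central finite differences on a small random problem.
rng = np.random.default_rng(1)
z = rng.normal(size=(5, 4))
y = rng.integers(0, 4, size=5)
grad, hess = softmax_ce_grad_hess(y, z)

eps = 1e-4
num_grad = np.zeros_like(z)
num_hess = np.zeros_like(z)
for k in range(z.shape[1]):
    # Perturbing column k for all rows at once works because each row's loss
    # depends only on that row's scores.
    step = np.zeros_like(z)
    step[:, k] = eps
    f_plus, f_0, f_minus = per_row_loss(y, z + step), per_row_loss(y, z), per_row_loss(y, z - step)
    num_grad[:, k] = (f_plus - f_minus) / (2 * eps)
    num_hess[:, k] = (f_plus - 2 * f_0 + f_minus) / eps ** 2

print(np.allclose(grad, num_grad, atol=1e-6), np.allclose(hess, num_hess, atol=1e-4))
```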

Am I doing something wrong? Or is this a bug?
Thanks for your help!

Reproducible example

from sklearn.datasets import make_blobs
import lightgbm as lgb
from sklearn.model_selection import train_test_split
import numpy as np
from scipy.special import softmax
from sklearn.metrics import accuracy_score
import random
import xgboost as xgb

C = 2
num_class = 4
# Data is a Gaussian blob in each of the 4 quadrants, numbered clockwise:
# upper right: 0
# lower right: 1
# lower left: 2
# upper left: 3
X, Y = make_blobs(1000, n_features=2, centers=[(C, C), (C, -C), (-C, -C), (-C, C)])
X_train, X_val, y_train, y_val = train_test_split(X, Y, test_size=.2, random_state=0)
# for debugging, get the training data shuffled in a random order
reorder = list(range(X_train.shape[0]))
random.shuffle(reorder)
X_train_shuf = X_train[reorder, :]
y_train_shuf = y_train[reorder]

# code for loss function
class Loss(object):
    # Use https://maxhalford.github.io/blog/lightgbm-focal-loss/ as a template for a binary loss function
    def __init__(self, num_class, flatten_type='F', model_type='lgb'):
        self.num_class = num_class if num_class != 1 else 2
        self.flatten_type = flatten_type
        self.model_type = model_type

    def loss_function(self, y_true, raw_pred):
        # calculate probs as softmax of raw_pred, then get the correct index according to y_true
        p = self.softmax(raw_pred.reshape(y_true.shape[0], -1))
        p = np.clip(p, 1e-15, 1 - 1e-15)
        pt = p[np.arange(y_true.shape[0]), y_true.astype(int)]
        val = -1 * np.log(pt)
        return val

    def init_score(self, y_true):
        class_count = np.array([y_true[y_true == i].shape[0] for i in range(self.num_class)])
        output = np.ones((y_true.shape[0], self.num_class)) * class_count / np.sum(class_count)
        output = np.clip(output, 1e-15, 1 - 1e-15)
        return output.reshape(-1)

    def feval(self, preds, data):
        y_true = data.get_label()
        probs = preds.reshape((y_true.shape[0], -1))
        is_higher_better = False
        if self.model_type == 'lgb':
            return 'loss', self.loss_function(y_true, probs).mean(), is_higher_better
        else:
            return 'loss', self.loss_function(y_true, probs).mean()

    def predict(self, model, X, y_fit):
        return self.softmax(self.init_score(y_fit).reshape((-1, self.num_class))[0] + model.predict(X))

    def objective(self, raw_preds, data):
        y_true = data.get_label()
        probs = raw_preds.reshape((y_true.shape[0], -1))
        return self.grad(y_true, probs), self.hess(y_true, probs)

    def grad(self, y_true, y_pred):
        # To avoid errors in calculating derivatives, use numerical approximations
        func = lambda x : self.loss_function(y_true, x)
        grad = np.asarray([numerical_gradient_vectorized(func, point=y_pred, i=i) for i in range(self.num_class)])
        grad = grad.flatten(self.flatten_type)
        return grad

    def hess(self, y_true, y_pred):
        # To avoid errors in calculating derivatives, use numerical approximations
        func = lambda z: self.loss_function(y_true=y_true, raw_pred=z)
        diag_hess = [numerical_second_derivative_vectorized(func, y_pred, i, i) for i in range(self.num_class)]
        diag_hess = np.asarray(diag_hess)
        # adding 1e-6 is suggested to avoid 0 Hessians here: https://github.com/Microsoft/LightGBM/issues/1644
        # Although it does not seem to make an impact whether or not I add 1e-6
        return diag_hess.flatten(self.flatten_type) + 1e-6
    
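    def softmax(self, z):
        # Subtract the row-wise max before exponentiating to avoid overflow; the result is unchanged.
        z = z - np.max(z, axis=1, keepdims=True)
        return np.exp(z) / np.sum(np.exp(z), axis=1, keepdims=True)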
    
def numerical_second_derivative_vectorized(func, point, i, j, eps=.0001):
    # point has shape (n_samples, num_class); forward-difference estimate of d2f / (dz_i dz_j) for every row
    point = point.astype('float64') # without the higher precision, sometimes it fails
    ei = np.zeros_like(point).astype('float64')
    ej = np.zeros_like(point).astype('float64')
    ei[:, i] = eps
    ej[:, j] = eps
    return (func(point + ei + ej) - func(point+ ei) - func(point + ej) + func(point)) / eps ** 2

def numerical_gradient_vectorized(func, point, i, eps=1e-6):
    point = point.astype('float64')
    ei = np.zeros_like(point).astype('float64')
    ei[:, i] = eps
    return (func(point + ei) - func(point)) / eps

# Training LightGBM model
lgb_loss = Loss(num_class=num_class, model_type='lgb')
train_dataset = lgb.Dataset(X_train, y_train)
train_dataset_shuf = lgb.Dataset(X_train_shuf, y_train_shuf)
val_dataset = lgb.Dataset(X_val, y_val)
lgb_model = lgb.train(params={'n_estimators': 50, 'num_leaves': 300, 'num_class': num_class,
                              'min_data_in_leaf': 1},
                      train_set=train_dataset,
                      valid_sets=(train_dataset, train_dataset_shuf, val_dataset),
                      valid_names=('train', 'train_shuf', 'val'),
                      fobj=lgb_loss.objective,
                      feval=lgb_loss.feval)
# Training loss decreases for the 'train' set. However, it does not decrease for the 'train_shuf' set, which 
# is the same training data, just in shuffled order, which makes me think that this is a bug.

def lgb_predict(model, X, y_train, lgb_loss=lgb_loss):
    # Must add the initialization score before making prediction, see here:
    # https://maxhalford.github.io/blog/lightgbm-focal-loss/
    raw_preds = lgb_loss.predict(model, X, y_train)
    return np.argmax(raw_preds, axis=1)

lgb_preds_train = lgb_predict(lgb_model, X_train, y_train)
lgb_preds_train_shuf = lgb_predict(lgb_model, X_train_shuf, y_train)
lgb_preds_val = lgb_predict(lgb_model, X_val, y_train)

print('lgb accuracy train:', accuracy_score(y_train, lgb_preds_train))
print('lgb accuracy train_shuf:', accuracy_score(y_train_shuf, lgb_preds_train_shuf))
print('lgb accuracy val:', accuracy_score(y_val, lgb_preds_val))

# As a sanity check that I've implemented the loss function correctly, train with XGBoost
xgb_loss = Loss(num_class=num_class, model_type='xgb')

dtrain = xgb.DMatrix(X_train, y_train)
dtrain_shuf = xgb.DMatrix(X_train_shuf, y_train_shuf)
dval = xgb.DMatrix(X_val, y_val)
xmodel = xgb.train({'num_class': num_class, 'disable_default_eval_metric': True},
                   dtrain,
                   num_boost_round=100,
                   obj=xgb_loss.objective)

print('XGB train accuracy: ', accuracy_score(y_train, xmodel.predict(dtrain)))
print('XGB train_shuf accuracy: ', accuracy_score(y_train_shuf, xmodel.predict(dtrain_shuf)))
print('XGB val accuracy: ', accuracy_score(y_val, xmodel.predict(dval)))
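The LightGBM 3.x docstring for `fobj` in `lgb.train` says that for multi-class tasks the scores passed to the objective are grouped by class first and then by row, so the score of row i for class j is at `score[j * num_data + i]`, and the returned grad and hess should be grouped the same way. Later LightGBM releases changed how custom objectives are passed and shaped, so it is worth checking the docs for the installed version. Assuming that class-major layout, the hypothetical helpers below convert between the flat buffer and an (n_samples, n_classes) matrix.

```python
import numpy as np


def class_major_to_matrix(flat_scores, num_data, num_class):
    """Flat buffer with score[j * num_data + i] -> array of shape (num_data, num_class)."""
    return flat_scores.reshape(num_class, num_data).T


def matrix_to_class_major(matrix):
    """Array of shape (num_data, num_class) -> flat buffer grouped by class, then row."""
    return matrix.T.reshape(-1)  # equivalent to matrix.ravel(order='F')


num_data, num_class = 3, 2
flat = np.array([0.0, 1.0, 2.0, 10.0, 11.0, 12.0])  # class 0 for rows 0..2, then class 1
m = class_major_to_matrix(flat, num_data, num_class)
print(m.tolist())  # [[0.0, 10.0], [1.0, 11.0], [2.0, 12.0]]
print(np.array_equal(matrix_to_class_major(m), flat))  # True
```

Under that assumption, an objective would call `class_major_to_matrix` on the incoming scores, compute a (num_data, num_class) gradient and Hessian, and return them through `matrix_to_class_major`; a custom `feval` would need the same conversion if, as the 3.x docstring describes, it receives predictions in the same grouping.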

Environment info

OS: Amazon Linux 2 (EC2 instance)
python: 3.7.10
LightGBM version or commit hash: 3.3.2

Command(s) you used to install LightGBM

pip install lightgbm -U

Additional Comments

Shortened output of the code:
Notice how the train loss decreases, while the losses on the shuffled training data and on the held-out set both increase.

[LightGBM] [Warning] Using self-defined objective function
[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000355 seconds.
You can set force_col_wise=true to remove the overhead.
[LightGBM] [Info] Total Bins 510
[LightGBM] [Info] Number of data points in the train set: 800, number of used features: 2
[LightGBM] [Warning] Using self-defined objective function
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[1] train's loss: 1.01643 train_shuf's loss: 1.41085 val's loss: 1.44438
[2] train's loss: 0.779389 train_shuf's loss: 1.45818 val's loss: 1.49208
[3] train's loss: 0.610527 train_shuf's loss: 1.51741 val's loss: 1.539
[4] train's loss: 0.484318 train_shuf's loss: 1.58413 val's loss: 1.579
[5] train's loss: 0.387396 train_shuf's loss: 1.65603 val's loss: 1.62837
...
[48] train's loss: 0.000383055 train_shuf's loss: 4.8523 val's loss: 4.07682
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[49] train's loss: 0.000361902 train_shuf's loss: 4.88315 val's loss: 4.10556
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[LightGBM] [Warning] No further splits with positive gain, best gain: -inf
[50] train's loss: 0.00034246 train_shuf's loss: 4.91395 val's loss: 4.13448
lgb accuracy train: 0.23625
lgb accuracy train_shuf: 0.25
lgb accuracy val: 0.25

XGB train accuracy: 0.99
XGB train_shuf accuracy: 0.99
XGB val accuracy: 0.945
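One possible explanation for this pattern, assuming LightGBM 3.x passes scores in the class-major layout its `fobj` docstring describes (`score[j * num_data + i]`): `Loss.objective` and `Loss.feval` both call `reshape((y_true.shape[0], -1))`, which reads consecutive entries of the flat buffer as one row, so each "row" mixes scores of different samples for the same class. Because the objective and the metric apply the same misalignment to the training set, boosting can steadily reduce that misaligned loss on `train`, while for datasets with a different row order (`train_shuf`, `val`) the scores no longer line up with the labels. The sketch below only illustrates the index mixing with labeled strings; it does not call LightGBM.

```python
import numpy as np

num_data, num_class = 4, 3
# Label each entry of a class-major buffer as "r<row>c<class>".
flat = np.array([f"r{i}c{j}" for j in range(num_class) for i in range(num_data)])

wrong = flat.reshape((num_data, -1))           # the reshape used in Loss.objective / Loss.feval
right = flat.reshape((num_class, num_data)).T  # class-major buffer -> (num_data, num_class)

print(wrong[0].tolist())  # ['r0c0', 'r1c0', 'r2c0']  (three rows of class 0)
print(right[0].tolist())  # ['r0c0', 'r0c1', 'r0c2']  (row 0 across all classes)
```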
