forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of hinge loss for binary classification (dmlc#3477)
- Loading branch information
1 parent
44811f2
commit 69454d9
Showing
5 changed files
with
94 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/*! | ||
* Copyright 2018 by Contributors | ||
* \file hinge.cc | ||
* \brief Provides an implementation of the hinge loss function | ||
* \author Henry Gouk | ||
*/ | ||
#include <xgboost/objective.h> | ||
#include "../common/math.h" | ||
|
||
namespace xgboost { | ||
namespace obj { | ||
|
||
DMLC_REGISTRY_FILE_TAG(hinge); | ||
|
||
class HingeObj : public ObjFunction { | ||
public: | ||
HingeObj() = default; | ||
|
||
void Configure( | ||
const std::vector<std::pair<std::string, std::string> > &args) override { | ||
// This objective does not take any parameters | ||
} | ||
|
||
void GetGradient(HostDeviceVector<bst_float> *preds, | ||
const MetaInfo &info, | ||
int iter, | ||
HostDeviceVector<GradientPair> *out_gpair) override { | ||
CHECK_NE(info.labels_.size(), 0U) << "label set cannot be empty"; | ||
CHECK_EQ(preds->Size(), info.labels_.size()) | ||
<< "labels are not correctly provided" | ||
<< "preds.size=" << preds->Size() | ||
<< ", label.size=" << info.labels_.size(); | ||
auto& preds_h = preds->HostVector(); | ||
|
||
out_gpair->Resize(preds_h.size()); | ||
auto& gpair = out_gpair->HostVector(); | ||
|
||
for (size_t i = 0; i < preds_h.size(); ++i) { | ||
auto y = info.labels_[i] * 2.0 - 1.0; | ||
bst_float p = preds_h[i]; | ||
bst_float w = info.GetWeight(i); | ||
bst_float g, h; | ||
if (p * y < 1.0) { | ||
g = -y * w; | ||
h = w; | ||
} else { | ||
g = 0.0; | ||
h = std::numeric_limits<bst_float>::min(); | ||
} | ||
gpair[i] = GradientPair(g, h); | ||
} | ||
} | ||
|
||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override { | ||
std::vector<bst_float> &preds = io_preds->HostVector(); | ||
for (auto& p : preds) { | ||
p = p > 0.0 ? 1.0 : 0.0; | ||
} | ||
} | ||
|
||
const char* DefaultEvalMetric() const override { | ||
return "error"; | ||
} | ||
}; | ||
|
||
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge") | ||
.describe("Hinge loss. Expects labels to be in [0,1f]") | ||
.set_body([]() { return new HingeObj(); }); | ||
|
||
} // namespace obj | ||
} // namespace xgboost |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// Copyright by Contributors | ||
#include <xgboost/objective.h> | ||
#include <limits> | ||
|
||
#include "../helpers.h" | ||
|
||
TEST(Objective, HingeObj) { | ||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:hinge"); | ||
std::vector<std::pair<std::string, std::string> > args; | ||
obj->Configure(args); | ||
xgboost::bst_float eps = std::numeric_limits<xgboost::bst_float>::min(); | ||
CheckObjFunction(obj, | ||
{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f}, | ||
{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f}, | ||
{ 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, | ||
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f}, | ||
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps }); | ||
|
||
ASSERT_NO_THROW(obj->DefaultEvalMetric()); | ||
} |