Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap mEstimators #117

Merged
merged 17 commits into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions gtsam.h
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,9 @@ virtual class Null: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Fair: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1370,6 +1373,8 @@ virtual class Fair: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
};

virtual class Huber: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1378,6 +1383,20 @@ virtual class Huber: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Cauchy: gtsam::noiseModel::mEstimator::Base {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
Cauchy(double k);
static gtsam::noiseModel::mEstimator::Cauchy* Create(double k);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Tukey: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1386,6 +1405,9 @@ virtual class Tukey: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class Welsch: gtsam::noiseModel::mEstimator::Base {
Expand All @@ -1394,8 +1416,23 @@ virtual class Welsch: gtsam::noiseModel::mEstimator::Base {

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

virtual class GemanMcClure: gtsam::noiseModel::mEstimator::Base {
GemanMcClure(double k);
static gtsam::noiseModel::mEstimator::GemanMcClure* Create(double k);

// enabling serialization functionality
void serializable() const;

double weight(double error) const;
double residual(double error) const;
};

//TODO DCS and L2WithDeadZone mEstimators
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved

}///\namespace mEstimator

Expand Down
82 changes: 78 additions & 4 deletions gtsam/linear/NoiseModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,15 +718,25 @@ void Null::print(const std::string &s="") const
Null::shared_ptr Null::Create()
{ return shared_ptr(new Null()); }

/* ************************************************************************* */
// Fair
/* ************************************************************************* */

Fair::Fair(double c, const ReweightScheme reweight) : Base(reweight), c_(c) {
if (c_ <= 0) {
throw runtime_error("mEstimator Fair takes only positive double in constructor.");
}
}

/* ************************************************************************* */
// Fair
/* ************************************************************************* */
double Fair::weight(const double error) const {
return 1.0 / (1.0 + std::abs(error) / c_);
}
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
double Fair::residual(const double error) const {
const double absError = std::abs(error);
const double normalizedError = absError / c_;
const double c_2 = c_ * c_;
return c_2 * (normalizedError - std::log(1 + normalizedError));
}

void Fair::print(const std::string &s="") const
{ cout << s << "fair (" << c_ << ")" << endl; }
Expand All @@ -750,6 +760,20 @@ Huber::Huber(double k, const ReweightScheme reweight) : Base(reweight), k_(k) {
}
}

double Huber::weight(const double error) const {
const double absError = std::abs(error);
return (absError <= k_) ? (1.0) : (k_ / absError);
}

double Huber::residual(const double error) const {
const double absError = std::abs(error);
if (absError <= k_) { // |x| <= k
return error*error / 2;
} else { // |x| > k
return k_ * (absError - (k_/2));
}
}

void Huber::print(const std::string &s="") const {
cout << s << "huber (" << k_ << ")" << endl;
}
Expand All @@ -774,6 +798,16 @@ Cauchy::Cauchy(double k, const ReweightScheme reweight) : Base(reweight), k_(k),
}
}

double Cauchy::weight(const double error) const {
return ksquared_ / (ksquared_ + error*error);
}

double Cauchy::residual(const double error) const {
const double xc2 = error / k_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit: xc2 is usually error*error/ksquared_, and then can do std::log(1 + xc2) below.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varunagrawal I know you can't finish this yet because of other priorities, but could you incorporate this change into the new PR?

const double val = std::log(1 + (xc2*xc2));
return ksquared_ * val * 0.5;
}

void Cauchy::print(const std::string &s="") const {
cout << s << "cauchy (" << k_ << ")" << endl;
}
Expand All @@ -791,7 +825,30 @@ Cauchy::shared_ptr Cauchy::Create(double c, const ReweightScheme reweight) {
/* ************************************************************************* */
// Tukey
/* ************************************************************************* */
Tukey::Tukey(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {}

Tukey::Tukey(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {
if (c <= 0) {
throw runtime_error("mEstimator Tukey takes only positive double in constructor.");
}
}

double Tukey::weight(const double error) const {
if (std::abs(error) <= c_) {
const double xc2 = error*error/csquared_;
return (1.0-xc2)*(1.0-xc2);
}
return 0.0;
}
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
double Tukey::residual(const double error) const {
double absError = std::abs(error);
if (absError <= c_) {
const double xc2 = error*error/csquared_;
const double t = (1 - xc2)*(1 - xc2)*(1 - xc2);
return csquared_ * (1 - t) / 6.0;
} else {
return csquared_ / 6.0;
}
}

void Tukey::print(const std::string &s="") const {
std::cout << s << ": Tukey (" << c_ << ")" << std::endl;
Expand All @@ -810,8 +867,19 @@ Tukey::shared_ptr Tukey::Create(double c, const ReweightScheme reweight) {
/* ************************************************************************* */
// Welsch
/* ************************************************************************* */

Welsch::Welsch(double c, const ReweightScheme reweight) : Base(reweight), c_(c), csquared_(c * c) {}

double Welsch::weight(const double error) const {
const double xc2 = (error*error)/csquared_;
return std::exp(-xc2);
}

double Welsch::residual(const double error) const {
const double xc2 = (error*error)/csquared_;
return csquared_ * 0.5 * (1 - std::exp(-xc2) );
}

void Welsch::print(const std::string &s="") const {
std::cout << s << ": Welsch (" << c_ << ")" << std::endl;
}
Expand Down Expand Up @@ -858,6 +926,12 @@ double GemanMcClure::weight(double error) const {
return c4/(c2error*c2error);
}

double GemanMcClure::residual(double error) const {
const double c2 = c_*c_;
const double error2 = error*error;
return 0.5 * (c2 * error2) / (c2 + error2);
}

void GemanMcClure::print(const std::string &s="") const {
std::cout << s << ": Geman-McClure (" << c_ << ")" << std::endl;
}
Expand Down
79 changes: 23 additions & 56 deletions gtsam/linear/NoiseModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,9 +635,9 @@ namespace gtsam {
* To illustrate, let's consider the least-squares (L2), L1, and Huber estimators as examples:
*
* Name Symbol Least-Squares L1-norm Huber
* Residual \rho(x) 0.5*x^2 |x| 0.5*x^2 if x<k, 0.5*k^2 + k|x-k| otherwise
* Derivative \phi(x) x sgn(x) x if x<k, k sgn(x) otherwise
* Weight w(x)=\phi(x)/x 1 1/|x| 1 if x<k, k/|x| otherwise
* Residual \rho(x) 0.5*x^2 |x| 0.5*x^2 if |x|<k, 0.5*k^2 + k|x-k| otherwise
* Derivative \phi(x) x sgn(x) x if |x|<k, k sgn(x) otherwise
* Weight w(x)=\phi(x)/x 1 1/|x| 1 if |x|<k, k/|x| otherwise
*
* With these definitions, D(\rho(x), p) = \phi(x) D(x,p) = w(x) x D(x,p) = w(x) D(L2(x), p),
* and hence we can solve the equivalent weighted least squares problem \sum w(r_i) \rho(r_i)
Expand Down Expand Up @@ -676,7 +676,7 @@ namespace gtsam {
* evaluating the total penalty. But for now, I'm leaving this residual method as pure
* virtual, so the existing mEstimators can inherit this default fallback behavior.
*/
virtual double residual(double error) const { return 0; };
virtual double residual(const double error) const { return 0; };

/*
* This method is responsible for returning the weight function for a given amount of error.
Expand All @@ -685,7 +685,7 @@ namespace gtsam {
* for details. This method is required when optimizing cost functions with robust penalties
* using iteratively re-weighted least squares.
*/
virtual double weight(double error) const = 0;
virtual double weight(const double error) const = 0;

virtual void print(const std::string &s) const = 0;
virtual bool equals(const Base& expected, double tol=1e-8) const = 0;
Expand Down Expand Up @@ -726,7 +726,8 @@ namespace gtsam {

Null(const ReweightScheme reweight = Block) : Base(reweight) {}
virtual ~Null() {}
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
virtual double weight(double /*error*/) const { return 1.0; }
virtual double weight(const double /*error*/) const { return 1.0; }
virtual double residual(const double error) const { return error; }
virtual void print(const std::string &s) const;
virtual bool equals(const Base& /*expected*/, double /*tol*/) const { return true; }
static shared_ptr Create() ;
Expand All @@ -749,9 +750,8 @@ namespace gtsam {
typedef boost::shared_ptr<Fair> shared_ptr;

Fair(double c = 1.3998, const ReweightScheme reweight = Block);
double weight(double error) const {
return 1.0 / (1.0 + std::abs(error) / c_);
}
double weight(const double error) const;
double residual(const double error) const;
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double c, const ReweightScheme reweight = Block) ;
Expand All @@ -775,10 +775,8 @@ namespace gtsam {
typedef boost::shared_ptr<Huber> shared_ptr;

Huber(double k = 1.345, const ReweightScheme reweight = Block);
double weight(double error) const {
double absError = std::abs(error);
return (absError < k_) ? (1.0) : (k_ / absError);
}
double weight(const double error) const;
double residual(const double error) const;
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;
Expand Down Expand Up @@ -806,9 +804,8 @@ namespace gtsam {
typedef boost::shared_ptr<Cauchy> shared_ptr;

Cauchy(double k = 0.1, const ReweightScheme reweight = Block);
double weight(double error) const {
return ksquared_ / (ksquared_ + error*error);
}
double weight(const double error) const;
double residual(const double error) const;
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;
Expand All @@ -832,13 +829,8 @@ namespace gtsam {
typedef boost::shared_ptr<Tukey> shared_ptr;

Tukey(double c = 4.6851, const ReweightScheme reweight = Block);
double weight(double error) const {
if (std::abs(error) <= c_) {
double xc2 = error*error/csquared_;
return (1.0-xc2)*(1.0-xc2);
}
return 0.0;
}
double weight(const double error) const;
double residual(const double error) const;
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;
Expand All @@ -862,10 +854,8 @@ namespace gtsam {
typedef boost::shared_ptr<Welsch> shared_ptr;

Welsch(double c = 2.9846, const ReweightScheme reweight = Block);
double weight(double error) const {
double xc2 = (error*error)/csquared_;
return std::exp(-xc2);
}
double weight(const double error) const;
double residual(const double error) const;
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;
Expand All @@ -885,31 +875,7 @@ namespace gtsam {
// Welsh implements the "Welsch" robust error model (Zhang97ivc)
// This was misspelled in previous versions of gtsam and should be
// removed in the future.
class GTSAM_EXPORT Welsh : public Base {
protected:
double c_, csquared_;

public:
typedef boost::shared_ptr<Welsh> shared_ptr;

Welsh(double c = 2.9846, const ReweightScheme reweight = Block);
double weight(double error) const {
double xc2 = (error*error)/csquared_;
return std::exp(-xc2);
}
void print(const std::string &s) const;
bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;

private:
/** Serialization function */
friend class boost::serialization::access;
template<class ARCHIVE>
void serialize(ARCHIVE & ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar & BOOST_SERIALIZATION_NVP(c_);
}
};
using Welsh = Welsch;
#endif

/// GemanMcClure implements the "Geman-McClure" robust error model
Expand All @@ -924,7 +890,8 @@ namespace gtsam {

GemanMcClure(double c = 1.0, const ReweightScheme reweight = Block);
virtual ~GemanMcClure() {}
virtual double weight(double error) const;
virtual double weight(const double error) const;
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
virtual double residual(const double error) const;
virtual void print(const std::string &s) const;
virtual bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;
Expand Down Expand Up @@ -953,7 +920,7 @@ namespace gtsam {

DCS(double c = 1.0, const ReweightScheme reweight = Block);
virtual ~DCS() {}
virtual double weight(double error) const;
virtual double weight(const double error) const;
virtual void print(const std::string &s) const;
virtual bool equals(const Base& expected, double tol=1e-8) const;
static shared_ptr Create(double k, const ReweightScheme reweight = Block) ;
Expand Down Expand Up @@ -984,11 +951,11 @@ namespace gtsam {
typedef boost::shared_ptr<L2WithDeadZone> shared_ptr;

L2WithDeadZone(double k, const ReweightScheme reweight = Block);
double residual(double error) const {
double residual(const double error) const {
const double abs_error = std::abs(error);
return (abs_error < k_) ? 0.0 : 0.5*(k_-abs_error)*(k_-abs_error);
}
double weight(double error) const {
double weight(const double error) const {
// note that this code is slightly uglier than above, because there are three distinct
// cases to handle (left of deadzone, deadzone, right of deadzone) instead of the two
// cases (deadzone, non-deadzone) above.
Expand Down
Loading