forked from chromium/chromium
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeneric_logistic_regression_inference.h
37 lines (30 loc) · 1.37 KB
/
generic_logistic_regression_inference.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_ASSIST_RANKER_GENERIC_LOGISTIC_REGRESSION_INFERENCE_H_
#define COMPONENTS_ASSIST_RANKER_GENERIC_LOGISTIC_REGRESSION_INFERENCE_H_
#include "components/assist_ranker/proto/generic_logistic_regression_model.pb.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
namespace assist_ranker {
// Maps a float to a float. Only declared here; the definition is not part of
// this header. Judging by the name, presumably the logistic function
// 1 / (1 + exp(-x)) -- TODO confirm against the definition.
float Sigmoid(float x);
// TODO(hamelphi): Implement an interface base class
// BinaryClassificationInferenceModule.
//
// Implements inference for a GenericLogisticRegressionModel.
//
// An instance holds its own const GenericLogisticRegressionModel (|proto_|)
// and exposes a thresholded boolean prediction (Predict) and a raw score
// (PredictScore) for a RankerExample. The member functions are only declared
// here; their definitions are not part of this header.
//
// NOTE(review): none of the member functions are const-qualified even though
// the only data member is const. Const-qualifying them would have to be done
// together with the out-of-line definitions.
class GenericLogisticRegressionInference {
public:
// Takes the model proto by value. Presumably it is stored in |proto_| --
// confirm against the constructor definition.
explicit GenericLogisticRegressionInference(
GenericLogisticRegressionModel model_proto);
// Returns a boolean decision given a RankerExample. Uses the same logic as
// PredictScore, and then applies the model decision threshold.
bool Predict(const RankerExample& example);
// Returns a score between 0 and 1 given a RankerExample.
float PredictScore(const RankerExample& example);
private:
// Returns the decision threshold. If no threshold is specified in the proto,
// use 0.5.
float GetThreshold();
// The model proto owned by this object. Declared const, so it cannot be
// modified after construction.
const GenericLogisticRegressionModel proto_;
};
} // namespace assist_ranker
#endif // COMPONENTS_ASSIST_RANKER_GENERIC_LOGISTIC_REGRESSION_INFERENCE_H_