Skip to content

Commit

Permalink
Use dynamic thresholds for Ranker predictions.
Browse files Browse the repository at this point in the history
Adds the ability to specify a threshold replacement to apply when predicting using an Assist-Ranker model.

For Contextual Search the threshold is read from a
FieldTrial variations parameter associated with the existing
ContextualSearchRankerQuery Feature.

For Translate, threshold setting is supported but unused.

BUG=899134

Change-Id: Ia0b664ce1e0949f755e7d8898fb2769fc91ae536
Reviewed-on: https://chromium-review.googlesource.com/c/1312296
Commit-Queue: Donn Denman <donnd@chromium.org>
Reviewed-by: Charles . <charleszhao@chromium.org>
Reviewed-by: Andrew Moylan <amoylan@chromium.org>
Cr-Commit-Position: refs/heads/master@{#612324}
  • Loading branch information
Donn Denman authored and Commit Bot committed Nov 29, 2018
1 parent f19bb47 commit b14d5c3
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 26 deletions.
4 changes: 4 additions & 0 deletions components/assist_ranker/base_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ GURL BasePredictor::GetModelUrl() const {
return GURL(config_.field_trial_url_param->Get());
}

float BasePredictor::GetPredictThresholdReplacement() const {
return config_.field_trial_threshold_replacement_param;
}

RankerExample BasePredictor::PreprocessExample(const RankerExample& example) {
if (ranker_model_->proto().has_metadata() &&
ranker_model_->proto().metadata().input_features_names_are_hex_hashes()) {
Expand Down
9 changes: 8 additions & 1 deletion components/assist_ranker/base_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,18 @@ class UkmEntryBuilder;

namespace assist_ranker {

// Value to use for when no prediction threshold replacement should be applied.
// See |GetPredictThresholdReplacement| method.
const float kNoPredictThresholdReplacement = 0.0;

class Feature;
class RankerExample;
class RankerModel;

// Predictors are objects that provide an interface for prediction, as well as
// encapsulate the logic for loading the model and logging. Sub-classes of
// BasePredictor implement an interface that depends on the nature of the
// suported model. Subclasses of BasePredictor will also need to implement an
// supported model. Subclasses of BasePredictor will also need to implement an
// Initialize method that will be called once the model is available, and a
// static validation function with the following signature:
//
Expand All @@ -49,6 +53,9 @@ class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {

// Returns the model URL.
GURL GetModelUrl() const;
// Returns the threshold to use for prediction, or
// kNoPredictThresholdReplacement to leave it unchanged.
float GetPredictThresholdReplacement() const;
// Returns the model name.
std::string GetModelName() const;

Expand Down
46 changes: 30 additions & 16 deletions components/assist_ranker/base_predictor_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,37 +54,41 @@ const base::Feature kTestRankerQuery{"TestRankerQuery",
const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, kTestUrlParamName, kTestDefaultModelUrl};

const PredictorConfig kTestPredictorConfig = PredictorConfig{
kTestModelName, kTestLoggingName, kTestUmaPrefixName, LOG_UKM,
&kFeatureWhitelist, &kTestRankerQuery, &kTestRankerUrl};
const PredictorConfig kTestPredictorConfig =
PredictorConfig{kTestModelName, kTestLoggingName,
kTestUmaPrefixName, LOG_UKM,
&kFeatureWhitelist, &kTestRankerQuery,
&kTestRankerUrl, kNoPredictThresholdReplacement};

// Class that implements virtual functions of the base class.
class FakePredictor : public BasePredictor {
public:
static std::unique_ptr<FakePredictor> Create();
// Creates a |FakePredictor| using the default config (from this file).
static std::unique_ptr<FakePredictor> Create() {
return Create(kTestPredictorConfig);
}
// Creates a |FakePredictor| using the |PredictorConfig| passed in
// |predictor_config|.
static std::unique_ptr<FakePredictor> Create(
PredictorConfig predictor_config);
~FakePredictor() override{};
// Validation will always succeed.
static RankerModelStatus ValidateModel(const RankerModel& model);
static RankerModelStatus ValidateModel(const RankerModel& model) {
return RankerModelStatus::OK;
}

protected:
// Not implementing any inference logic.
bool Initialize() override { return true; };

private:
FakePredictor(const PredictorConfig& config);
FakePredictor(const PredictorConfig& config) : BasePredictor(config) {}
DISALLOW_COPY_AND_ASSIGN(FakePredictor);
};

FakePredictor::FakePredictor(const PredictorConfig& config)
: BasePredictor(config) {}

RankerModelStatus FakePredictor::ValidateModel(const RankerModel& model) {
return RankerModelStatus::OK;
}

std::unique_ptr<FakePredictor> FakePredictor::Create() {
std::unique_ptr<FakePredictor> predictor(
new FakePredictor(kTestPredictorConfig));
std::unique_ptr<FakePredictor> FakePredictor::Create(
PredictorConfig predictor_config) {
std::unique_ptr<FakePredictor> predictor(new FakePredictor(predictor_config));
auto ranker_model = std::make_unique<RankerModel>();
auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
base::BindRepeating(&FakePredictor::ValidateModel),
Expand Down Expand Up @@ -184,4 +188,14 @@ TEST_F(BasePredictorTest, LogExampleToUkm) {
GetTestUkmRecorder()->EntryHasMetric(entries[0], kFeatureNotWhitelisted));
}

TEST_F(BasePredictorTest, GetPredictThresholdReplacement) {
float altered_threshold = 0.78f; // Arbitrary value.
const PredictorConfig altered_threshold_config{
kTestModelName, kTestLoggingName, kTestUmaPrefixName,
LOG_UKM, &kFeatureWhitelist, &kTestRankerQuery,
&kTestRankerUrl, altered_threshold};
auto predictor = FakePredictor::Create(altered_threshold_config);
EXPECT_EQ(altered_threshold, predictor->GetPredictThresholdReplacement());
}

} // namespace assist_ranker
10 changes: 9 additions & 1 deletion components/assist_ranker/binary_classifier_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ std::unique_ptr<BinaryClassifierPredictor> BinaryClassifierPredictor::Create(
const GURL& model_url = predictor->GetModelUrl();
DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
DVLOG(1) << "Model URL: " << model_url;
DVLOG(1) << "Using predict threshold replacement: "
<< predictor->GetPredictThresholdReplacement();
auto model_loader = std::make_unique<RankerModelLoaderImpl>(
base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
Expand All @@ -52,7 +54,13 @@ bool BinaryClassifierPredictor::Predict(const RankerExample& example,
return false;
}

*prediction = inference_module_->Predict(PreprocessExample(example));
float predict_threshold_replacement = GetPredictThresholdReplacement();
if (predict_threshold_replacement != kNoPredictThresholdReplacement) {
*prediction = inference_module_->PredictScore(PreprocessExample(example)) >=
predict_threshold_replacement;
} else {
*prediction = inference_module_->Predict(PreprocessExample(example));
}
DVLOG(1) << "Predictor " << GetModelName() << " predicted: " << *prediction;
return true;
}
Expand Down
34 changes: 33 additions & 1 deletion components/assist_ranker/binary_classifier_predictor_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class BinaryClassifierPredictorTest : public ::testing::Test {
GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();

PredictorConfig GetConfig();
PredictorConfig GetConfig(float predictor_threshold_replacement);

protected:
const std::string feature_ = "feature";
Expand Down Expand Up @@ -66,9 +67,14 @@ const base::FeatureParam<std::string> kTestRankerUrl{
&kTestRankerQuery, "url-param-name", "https://default.model.url"};

PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
return GetConfig(kNoPredictThresholdReplacement);
}

PredictorConfig BinaryClassifierPredictorTest::GetConfig(
float predictor_threshold_replacement) {
PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE,
GetEmptyWhitelist(), &kTestRankerQuery,
&kTestRankerUrl);
&kTestRankerUrl, predictor_threshold_replacement);

return config;
}
Expand Down Expand Up @@ -171,4 +177,30 @@ TEST_F(BinaryClassifierPredictorTest,
EXPECT_LT(float_response, threshold_);
}

TEST_F(BinaryClassifierPredictorTest,
GenericLogisticRegressionPreprocessedModelReplacedThreshold) {
auto ranker_model = std::make_unique<RankerModel>();
auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
glr = GetSimpleLogisticRegressionModel();
glr.clear_weights();
glr.set_is_preprocessed_model(true);
(*glr.mutable_fullname_weights())[feature_] = weight_;

float high_threshold = 0.9; // Some high threshold.
auto predictor =
InitPredictor(std::move(ranker_model), GetConfig(high_threshold));
EXPECT_TRUE(predictor->IsReady());

RankerExample ranker_example;
auto& features = *ranker_example.mutable_features();
features[feature_].set_bool_value(true);
bool bool_response;
EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
EXPECT_FALSE(bool_response);
float float_response;
EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
EXPECT_GT(float_response, threshold_);
EXPECT_LT(float_response, high_threshold);
}

} // namespace assist_ranker
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class GenericLogisticRegressionInference {
// Returns a boolean decision given a RankerExample. Uses the same logic as
// PredictScore, and then applies the model decision threshold.
bool Predict(const RankerExample& example);
// Returns a score between 0 and 1 give a RankerExample.
// Returns a score between 0 and 1 given a RankerExample.
float PredictScore(const RankerExample& example);

private:
Expand Down
14 changes: 9 additions & 5 deletions components/assist_ranker/predictor_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,25 @@ struct PredictorConfig {
const LogType log_type,
const base::flat_set<std::string>* feature_whitelist,
const base::Feature* field_trial,
const base::FeatureParam<std::string>* field_trial_url_param)
const base::FeatureParam<std::string>* field_trial_url_param,
float field_trial_threshold_replacement_param)
: model_name(model_name),
logging_name(logging_name),
uma_prefix(uma_prefix),
log_type(log_type),
feature_whitelist(feature_whitelist),
field_trial(field_trial),
field_trial_url_param(field_trial_url_param) {}
const char* model_name;
const char* logging_name;
const char* uma_prefix;
field_trial_url_param(field_trial_url_param),
field_trial_threshold_replacement_param(
field_trial_threshold_replacement_param) {}
const char* const model_name;
const char* const logging_name;
const char* const uma_prefix;
const LogType log_type;
const base::flat_set<std::string>* feature_whitelist;
const base::Feature* field_trial;
const base::FeatureParam<std::string>* field_trial_url_param;
const float field_trial_threshold_replacement_param;
};

} // namespace assist_ranker
Expand Down
13 changes: 12 additions & 1 deletion components/assist_ranker/predictor_config_definitions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// found in the LICENSE file.

#include "components/assist_ranker/predictor_config_definitions.h"
#include "components/assist_ranker/base_predictor.h"

namespace assist_ranker {

Expand All @@ -28,6 +29,15 @@ GetContextualSearchRankerUrlFeatureParam() {
return kContextualSearchRankerUrl;
}

float GetContextualSearchRankerThresholdFeatureParam() {
static auto* kContextualSearchRankerThreshold =
new base::FeatureParam<double>(
&kContextualSearchRankerQuery,
"contextual-search-ranker-predict-threshold",
kNoPredictThresholdReplacement);
return static_cast<float>(kContextualSearchRankerThreshold->Get());
}

// NOTE: This list needs to be kept in sync with tools/metrics/ukm/ukm.xml!
// Only features within this list will be logged to UKM.
// TODO(chrome-ranker-team) Deprecate the whitelist once it is available through
Expand Down Expand Up @@ -77,7 +87,8 @@ const PredictorConfig GetContextualSearchPredictorConfig() {
kContextualSearchModelName, kContextualSearchLoggingName,
kContextualSearchUmaPrefixName, LOG_UKM,
GetContextualSearchFeatureWhitelist(), &kContextualSearchRankerQuery,
GetContextualSearchRankerUrlFeatureParam()));
GetContextualSearchRankerUrlFeatureParam(),
GetContextualSearchRankerThresholdFeatureParam()));
return kContextualSearchPredictorConfig;
}
#endif // OS_ANDROID
Expand Down

0 comments on commit b14d5c3

Please sign in to comment.