Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: direct interface for active.cc and variable rename for understandability #4671

Merged
merged 11 commits into from
Mar 7, 2024
19 changes: 6 additions & 13 deletions test/train-sets/ref/active-simulation.t24.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,13 @@ Output pred = SCALAR
average since example example current current current
loss last counter weight label predict features
1.000000 1.000000 1 1.0 -1.0000 0.0000 128
0.791125 0.755288 2 6.8 -1.0000 -0.1309 44
1.274829 1.444750 8 26.3 1.0000 -0.2020 34
1.083985 0.895011 73 52.8 1.0000 0.0214 21
0.887295 0.693362 130 106.3 -1.0000 -0.3071 146
0.788245 0.690009 233 213.6 -1.0000 0.0421 47
0.664628 0.541195 398 427.4 -1.0000 -0.1863 68
0.634406 0.604328 835 856.9 -1.0000 -0.4327 40

finished run
number of examples = 1000
weighted example sum = 1014.004519
weighted label sum = -68.618036
average loss = 0.630964
best constant = -0.067670
best constant's loss = 0.995421
weighted example sum = 1.000000
weighted label sum = -1.000000
average loss = 1.000000
best constant = -1.000000
best constant's loss = 0.000000
total feature number = 78739
total queries = 474
total queries = 1
8 changes: 6 additions & 2 deletions test/train-sets/ref/help.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,12 @@ Weight Options:
[Reduction] Active Learning Options:
--active Enable active learning (type: bool, keep, necessary)
--simulation Active learning simulation mode (type: bool)
--mellowness arg Active learning mellowness parameter c_0. Default 8 (type: float,
default: 8, keep)
--direct Active learning via the tag and predictions interface. Tag should
start with "query?" to get query decision. Returned prediction
is either -1 for no or the importance weight for yes. (type:
bool)
--mellowness arg Active learning mellowness parameter c_0. Default 1. (type: float,
default: 1, keep)
[Reduction] Active Learning with Cover Options:
--active_cover Enable active learning with cover (type: bool, keep, necessary)
--mellowness arg Active learning mellowness parameter c_0 (type: float, default:
Expand Down
83 changes: 66 additions & 17 deletions vowpalwabbit/core/src/reductions/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,41 @@
using namespace VW::reductions;
namespace
{
float get_active_coin_bias(float k, float avg_loss, float g, float c0)
{
const float b = c0 * (std::log(k + 1.f) + 0.0001f) / (k + 0.0001f);
const float sb = std::sqrt(b);
float get_active_coin_bias(float example_count, float avg_loss, float alt_label_error_rate_diff, float mellowness)
{//implementation follows https://web.archive.org/web/20120525164352/http://books.nips.cc/papers/files/nips23/NIPS2010_0363.pdf
const float mellow_log_e_count_over_e_count = mellowness * (std::log(example_count + 1.f) + 0.0001f) / (example_count + 0.0001f);
const float sqrt_mellow_lecoec = std::sqrt(mellow_log_e_count_over_e_count);
// loss should be in [0,1]
avg_loss = VW::math::clamp(avg_loss, 0.f, 1.f);

const float sl = std::sqrt(avg_loss) + std::sqrt(avg_loss + g);
if (g <= sb * sl + b) { return 1; }
const float rs = (sl + std::sqrt(sl * sl + 4 * g)) / (2 * g);
return b * rs * rs;
const float sqrt_avg_loss_plus_sqrt_alt_loss = std::min(1.f, //std::sqrt(avg_loss) + // commented out because two square roots appears to conservative.
std::sqrt(avg_loss + alt_label_error_rate_diff));//emperical variance deflater.
//std::cout << "example_count = " << example_count << " avg_loss = " << avg_loss << " alt_label_error_rate_diff = " << alt_label_error_rate_diff << " mellowness = " << mellowness << " mlecoc = " << mellow_log_e_count_over_e_count
// << " sqrt_mellow_lecoec = " << sqrt_mellow_lecoec << " double sqrt = " << sqrt_avg_loss_plus_sqrt_alt_loss << std::endl;

if (alt_label_error_rate_diff <= sqrt_mellow_lecoec * sqrt_avg_loss_plus_sqrt_alt_loss//deflater in use.
+ mellow_log_e_count_over_e_count) { return 1; }
//old equation
// const float rs = (sqrt_avg_loss_plus_sqrt_alt_loss + std::sqrt(sqrt_avg_loss_plus_sqrt_alt_loss * sqrt_avg_loss_plus_sqrt_alt_loss + 4 * alt_label_error_rate_diff)) / (2 * alt_label_error_rate_diff);
// return mellow_log_e_count_over_e_count * rs * rs;
const float sqrt_s = (sqrt_mellow_lecoec + std::sqrt(mellow_log_e_count_over_e_count+4*alt_label_error_rate_diff*mellow_log_e_count_over_e_count)) / 2*alt_label_error_rate_diff;
// std::cout << "sqrt_s = " << sqrt_s << std::endl;
return sqrt_s*sqrt_s;
}

float query_decision(const active& a, float ec_revert_weight, float k)
float query_decision(const active& a, float updates_to_change_prediction, float example_count)
{
float bias;
if (k <= 1.f) { bias = 1.f; }
if (example_count <= 1.f) { bias = 1.f; }
else
{
const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / k) +
std::sqrt((1.f + 0.5f * std::log(k)) / (weighted_queries + 0.0001f));
bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0);
// const auto weighted_queries = static_cast<float>(a._shared_data->weighted_labeled_examples);
const float avg_loss = (static_cast<float>(a._shared_data->sum_loss) / example_count);
//+ std::sqrt((1.f + 0.5f * std::log(example_count)) / (weighted_queries + 0.0001f)); Commented this out, not following why we need it from the theory.
// std::cout << "avg_loss = " << avg_loss << " weighted_queries = " << weighted_queries << " sum_loss = " << a._shared_data->sum_loss << " example_count = " << example_count << std::endl;
bias = get_active_coin_bias(example_count, avg_loss, updates_to_change_prediction / example_count, a.active_c0);
}

// std::cout << "bias = " << bias << std::endl;
return (a._random_state->get_and_update_random() < bias) ? 1.f / bias : -1.f;
}

Expand Down Expand Up @@ -110,6 +120,34 @@
}
}

/**
 * Learn/predict entry point for the "--direct" active learning interface.
 *
 * The base learner is always run first. Afterwards:
 *  - if the example's label is FLT_MAX and its tag starts with "query?", the scalar prediction
 *    is overwritten with the query decision: the importance weight when the label should be
 *    queried, -1 otherwise;
 *  - if the label is FLT_MAX and the tag does not start with "query?", the base prediction is
 *    left untouched;
 *  - otherwise the smallest and largest labels seen so far are updated from this example.
 *
 * @tparam is_learn call base.learn() when true, base.predict() when false
 * @param a active learning state
 * @param base the learner below this reduction
 * @param ec the example being processed; pred.scalar and confidence may be overwritten
 */
template <bool is_learn>
void predict_or_learn_active_direct(active& a, learner& base, VW::example& ec)
{
  if (is_learn) { base.learn(ec); }
  else { base.predict(ec); }

  if (ec.l.simple.label == FLT_MAX)
  {
    // Only touch the tag's first characters once we know they exist: building the prefix
    // string from a tag shorter than "query?" (or an empty tag) would read past its end.
    constexpr size_t query_prefix_length = 6;  // length of "query?"
    if (ec.tag.size() >= query_prefix_length &&
        std::string(ec.tag.begin(), ec.tag.begin() + query_prefix_length) == "query?")
    {
      const float threshold = (a._shared_data->max_label + a._shared_data->min_label) * 0.5f;
      // We want to understand the change in prediction if the label were to be
      // the opposite of what was predicted. The smallest and largest labels seen
      // so far stand in for the two possible labels.
      ec.l.simple.label = (ec.pred.scalar >= threshold) ? a._min_seen_label : a._max_seen_label;
      ec.confidence = std::abs(ec.pred.scalar - threshold) / base.sensitivity(ec);
      // The label was set only for the sensitivity computation; restore the FLT_MAX marker.
      ec.l.simple.label = FLT_MAX;
      ec.pred.scalar =
          query_decision(a, ec.confidence, static_cast<float>(a._shared_data->weighted_unlabeled_examples));
    }
  }
  else
  {
    // Update seen labels based on the current example's label.
    a._min_seen_label = std::min(ec.l.simple.label, a._min_seen_label);
    a._max_seen_label = std::max(ec.l.simple.label, a._max_seen_label);
  }
}

Check warning on line 149 in vowpalwabbit/core/src/reductions/active.cc

View check run for this annotation

Codecov / codecov/patch

vowpalwabbit/core/src/reductions/active.cc#L149

Added line #L149 was not covered by tests

void active_print_result(
VW::io::writer* f, float res, float weight, const VW::v_array<char>& tag, VW::io::logger& logger)
{
Expand Down Expand Up @@ -189,14 +227,16 @@

bool active_option = false;
bool simulation = false;
bool direct = false;
float active_c0;
option_group_definition new_options("[Reduction] Active Learning");
new_options.add(make_option("active", active_option).keep().necessary().help("Enable active learning"))
.add(make_option("simulation", simulation).help("Active learning simulation mode"))
.add(make_option("direct", direct).help("Active learning via the tag and predictions interface. Tag should start with \"query?\" to get query decision. Returned prediction is either -1 for no or the importance weight for yes."))
.add(make_option("mellowness", active_c0)
.keep()
.default_value(8.f)
.help("Active learning mellowness parameter c_0. Default 8"));
.default_value(1.f)
.help("Active learning mellowness parameter c_0. Default 1."));

if (!options.add_parse_and_check_necessary(new_options)) { return nullptr; }

Expand All @@ -223,6 +263,15 @@
print_update_func = VW::details::print_update_simple_label<active>;
reduction_name.append("-simulation");
}
else if (direct)
{
learn_func = predict_or_learn_active_direct<true>;
pred_func = predict_or_learn_active_direct<false>;
update_stats_func = update_stats_active;
output_example_prediction_func = VW::details::output_example_prediction_simple_label<active>;
print_update_func = VW::details::print_update_simple_label<active>;
learn_returns_prediction = base->learn_returns_prediction;

Check warning on line 273 in vowpalwabbit/core/src/reductions/active.cc

View check run for this annotation

Codecov / codecov/patch

vowpalwabbit/core/src/reductions/active.cc#L268-L273

Added lines #L268 - L273 were not covered by tests
}
else
{
all.reduction_state.active = true;
Expand Down
Loading