Skip to content

Commit

Permalink
model communicates logits rather than bools (good/bad)
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Suda committed Oct 16, 2020
1 parent 5473196 commit 110f414
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 14 deletions.
4 changes: 2 additions & 2 deletions Kernel/Clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Clause::Clause(unsigned length,const Inference& inf)
_component(false),
_store(NONE),
_numSelected(0),
_modelSaidYes(1), // be optimistic by default (delayed eval takes care of demoting the bad guys)
_modelSaid(std::numeric_limits<decltype(_modelSaid)>::lowest()), // be optimistic by default (delayed eval takes care of demoting the bad guys)
_weight(0),
_weightForClauseSelection(0),
_refCnt(0),
Expand Down Expand Up @@ -439,7 +439,7 @@ vstring Clause::toString() const
}

if(env.options->evalForKarel()) {
result += ",msY:" + Int::toString(_modelSaidYes);
result += ",msY:" + Int::toString(_modelSaid);
}

result += ",thAx:" + Int::toString((int)(_inference.th_ancestors));
Expand Down
10 changes: 5 additions & 5 deletions Kernel/Clause.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,12 @@ class Clause

unsigned numPositiveLiterals(); // number of positive literals in the clause

void modelSaid(bool value) {
_modelSaidYes = value;
void setModelSaid(float value) {
_modelSaid = value;
}

bool modelSaidYes() const {
return _modelSaidYes;
float modelSaid() const {
return _modelSaid;
}

protected:
Expand All @@ -374,7 +374,7 @@ class Clause
/** number of selected literals */
unsigned _numSelected : 20;

unsigned _modelSaidYes : 1;
float _modelSaid; // small is good

/** weight */
mutable unsigned _weight;
Expand Down
13 changes: 9 additions & 4 deletions Saturation/PredicateSplitPassiveClauseContainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ void PredicateSplitPassiveClauseContainer::add(Clause* cl)
ASS(cl->store() == Clause::PASSIVE);

auto bestQueueIndex = bestQueue(evaluateFeature(cl));

if (_layeredArrangement)
{
// add clause to all queues starting from best queue for clause
Expand Down Expand Up @@ -236,11 +237,12 @@ Clause* PredicateSplitPassiveClauseContainer::popSelected()

if (currIndex < (long int)_queues.size()-1 && // not the last index
_delayedEvaluator) { // we can re-evaluate

_delayedEvaluator(cl);

if (evaluateFeature(cl) > _cutoffs[currIndex]) {
// we don't like the clause here!
// cout << "Didn't like " << cl->number() << " in " << currIndex << endl;
//cout << "Didn't like " << cl->number() << " in " << currIndex << endl;
goto search_for_an_appropriate_queue;
}
}
Expand Down Expand Up @@ -566,20 +568,23 @@ float PositiveLiteralMultiSplitPassiveClauseContainer::evaluateFeatureEstimate(u

NeuralEvalSplitPassiveClauseContainer::NeuralEvalSplitPassiveClauseContainer(bool isOutermost, const Shell::Options &opt, Lib::vstring name, Lib::vvector<std::unique_ptr<PassiveClauseContainer>> queues) :
PredicateSplitPassiveClauseContainer(isOutermost, opt, name, std::move(queues),
Lib::vvector<float>({0.0, std::numeric_limits<float>::max()}),
opt.neuralEvalSplitQueueCutoffs(),
opt.neuralEvalSplitQueueRatios(),
true /* monotone queue split hard-wired here */) {}

float NeuralEvalSplitPassiveClauseContainer::evaluateFeature(Clause* cl) const
{
CALL("NeuralEvalSplitPassiveClauseContainer::evaluateFeature");
return 1.0-(float)cl->modelSaidYes(); // 0.0 is good, 1.0 is bad (because the hard-wired < comparison in PredicateSplitPassiveClauseContainer)

// cout << "evaluateFeature " << cl->number() << " " << cl->modelSaid() << endl;

return cl->modelSaid(); // small is good, large is bad
}

float NeuralEvalSplitPassiveClauseContainer::evaluateFeatureEstimate(unsigned numPositiveLiterals, const Inference& inference) const
{
CALL("NeuralEvalSplitPassiveClauseContainer::evaluateFeatureEstimate");
return 0.0; // simply estimate that the clause is good
return std::numeric_limits<float>::lowest(); // simply estimate that the clause is good
}

};
6 changes: 3 additions & 3 deletions Saturation/SaturationAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ static std::unique_ptr<PassiveClauseContainer> makeLevel5(bool isOutermost, cons
if (opt.useNeuralEvalSplitQueues())
{
Lib::vvector<std::unique_ptr<PassiveClauseContainer>> queues;
Lib::vvector<float> cutoffs = opt.positiveLiteralSplitQueueCutoffs();
Lib::vvector<float> cutoffs = opt.neuralEvalSplitQueueCutoffs();
for (unsigned i = 0; i < cutoffs.size(); i++)
{
auto queueName = name + "NESQ" + Int::toString(cutoffs[i]) + ":";
Expand Down Expand Up @@ -540,7 +540,7 @@ void SaturationAlgorithm::embed_and_evaluate(Clause* cl, const char* method_name
inputs.push_back(id);

auto out = _model.forward(inputs);
cl->modelSaid(out.toBool());
cl->setModelSaid(-out.toDouble()); // already here, we reverse the logit's logic to "small is good"!
}
}

Expand Down Expand Up @@ -1367,7 +1367,7 @@ void SaturationAlgorithm::addToPassive(Clause* cl)

if (_opt.evalForKarel()) {
// talkToKarel(cl); // delayed evaluation trick (TODO: do this for initial as well?)
cl->modelSaid(true);
// cl->modelSaid(true); // the clause is born as good; see Clause::Clause
}

{
Expand Down
33 changes: 33 additions & 0 deletions Shell/Options.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,12 @@ void Options::Options::init()
_useNeuralEvalSplitQueues.reliesOn(_evalForKarel.is(notEqual(vstring(""))));
_useNeuralEvalSplitQueues.tag(OptionTag::SATURATION);

_neuralEvalSplitQueueCutoffs = StringOptionValue("neural_eval_split_queue_cutoffs", "nesqc", "0");
_neuralEvalSplitQueueCutoffs.description = "The cutoff-values for the neural-eval-split-queues (the cutoff value for the last queue is omitted, since it has to be infinity).";
_lookup.insert(&_neuralEvalSplitQueueCutoffs);
_neuralEvalSplitQueueCutoffs.reliesOn(_useNeuralEvalSplitQueues.is(equal(true)));
_neuralEvalSplitQueueCutoffs.tag(OptionTag::SATURATION);

_neuralEvalSplitQueueRatios = StringOptionValue("neural_eval_split_queue_ratios", "nesqr", "10,1");
_neuralEvalSplitQueueRatios.description = "The ratios for picking clauses from the neural-eval-split-queues using weighted round robin. If a queue is empty, the clause will be picked from the next non-empty queue to the right."
" There should be exactly two numbers in the list.";
Expand Down Expand Up @@ -3529,3 +3535,30 @@ Lib::vvector<int> Options::neuralEvalSplitQueueRatios() const
return inputRatios;
}

Lib::vvector<float> Options::neuralEvalSplitQueueCutoffs() const
{
CALL("Options::neuralEvalSplitQueueCutoffs");
// initialize cutoffs and add float-max as last value
Lib::vvector<float> cutoffs;
vstringstream cutoffsStream(_neuralEvalSplitQueueCutoffs.actualValue);
std::string currentCutoff;
while (std::getline(cutoffsStream, currentCutoff, ','))
{
cutoffs.push_back(std::stof(currentCutoff));
}
cutoffs.push_back(std::numeric_limits<float>::max());

// sanity checks
for (unsigned i = 0; i < cutoffs.size(); i++)
{
auto cutoff = cutoffs[i];

if (i > 0 && cutoff <= cutoffs[i-1])
{
USER_ERROR("The cutoff values (supplied by option '-nesqc') must be strictly increasing");
}
}

return cutoffs;
}

2 changes: 2 additions & 0 deletions Shell/Options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,7 @@ bool _hard;

bool useNeuralEvalSplitQueues() const { return _useNeuralEvalSplitQueues.actualValue; }
Lib::vvector<int> neuralEvalSplitQueueRatios() const;
Lib::vvector<float> neuralEvalSplitQueueCutoffs() const;

void setWeightRatio(int v){ _ageWeightRatio.otherValue = v; }
AgeWeightRatioShape ageWeightRatioShape() const { return _ageWeightRatioShape.actualValue; }
Expand Down Expand Up @@ -2413,6 +2414,7 @@ bool _hard;
StringOptionValue _positiveLiteralSplitQueueCutoffs;
BoolOptionValue _positiveLiteralSplitQueueLayeredArrangement;
BoolOptionValue _useNeuralEvalSplitQueues;
StringOptionValue _neuralEvalSplitQueueCutoffs;
StringOptionValue _neuralEvalSplitQueueRatios;

BoolOptionValue _randomAWR;
Expand Down

0 comments on commit 110f414

Please sign in to comment.