Optionally apply logprob computation at the call site instead of the construction site #71

Open
@jerinphilip

Description

This issue tracks a possible workaround for making QE optional at run time rather than at construction time (skip-cost=true).

Trace:

skip-cost:

bool skipCost = options->get<bool>("skip-cost");
auto encdec = models::createModelFromOptions(
    options, skipCost ? models::usage::raw : models::usage::translation);

createModelFromOptions:

// add (log)softmax if requested
if (use == usage::translation) {
  if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
    if(options->get<bool>("output-sampling", false))
      return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
    else
      return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());

Stepwise:

// class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence,
// wrapped again in the IEncoderDecoder interface
// @TODO: seems we are conflating an interface definition with its implementation?
// @TODO: needs a better name. Stepwise is an adjective. Classes are things=nouns. StepwiseWhat?
class Stepwise : public IEncoderDecoder {

Stepwise, relevant call site:

virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
                               Ptr<DecoderState> state,
                               const std::vector<IndexType>& hypIndices,   // [beamIndex * activeBatchSize + batchIndex]
                               const Words& words,                         // [beamIndex * activeBatchSize + batchIndex]
                               const std::vector<IndexType>& batchIndices, // [batchIndex]
                               int beamSize) override {
  auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize);
  return cost_->apply(nextState);
}

If I add a bool skipCost parameter (defaulting to false) to the arguments here, skip the cost operation when skipCost=true, and pass the flag in from beam search (see below), would that be a workable approach?
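To make the proposal concrete, here is a minimal, self-contained sketch of a Stepwise-like wrapper whose step() takes a run-time skipCost flag. The types below are hypothetical stand-ins, not Marian's real classes: the actual step() also takes the graph, hypothesis indices, words, batch indices and beam size, and the skipCost parameter itself is the assumption being proposed, not an existing argument.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the decoder state; it only records which
// operations ran so the effect of the flag is observable.
struct DecoderState {
  std::vector<std::string> appliedOps;
};
using StatePtr = std::shared_ptr<DecoderState>;

// Simplified interface. skipCost is the proposed extra argument.
// The default lives on the interface only, because default arguments of
// virtual functions are bound to the static type of the caller's pointer.
struct IEncoderDecoder {
  virtual ~IEncoderDecoder() = default;
  virtual StatePtr step(StatePtr state, bool skipCost = false) = 0;
};

struct ILogProbStep {
  virtual ~ILogProbStep() = default;
  virtual StatePtr apply(StatePtr state) = 0;
};

// Stand-in for the raw model: produces the next state without a cost.
struct EncoderDecoder : IEncoderDecoder {
  StatePtr step(StatePtr state, bool /*skipCost*/) override {
    auto next = std::make_shared<DecoderState>(*state);
    next->appliedOps.push_back("decode");
    return next;
  }
};

// Stand-in for the (log)softmax step.
struct LogSoftmaxStep : ILogProbStep {
  StatePtr apply(StatePtr state) override {
    state->appliedOps.push_back("logsoftmax");
    return state;
  }
};

// Wrapper that runs the model and then, unless asked not to, the cost step.
class Stepwise : public IEncoderDecoder {
public:
  Stepwise(std::shared_ptr<IEncoderDecoder> encdec, std::shared_ptr<ILogProbStep> cost)
      : encdec_(std::move(encdec)), cost_(std::move(cost)) {}

  StatePtr step(StatePtr state, bool skipCost) override {
    auto nextState = encdec_->step(state, skipCost);
    if(skipCost)
      return nextState;             // raw output, no logprob computation
    return cost_->apply(nextState); // same behaviour as the current code
  }

private:
  std::shared_ptr<IEncoderDecoder> encdec_;
  std::shared_ptr<ILogProbStep> cost_;
};
```

Two design points follow from the trace above. First, the flag has to be added to the interface, not just to Stepwise, since callers hold the model through an interface pointer. Second, the model would have to be built with usage::translation so that the wrapper exists at all; with usage::raw there is no cost step to skip.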

Call-site:

states[i] = scorers_[i]->step(graph, states[i], hypIndices, prevWords, batchIndices, (int)maxBeamSize);

virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
                              Ptr<ScorerState> state,
                              const std::vector<IndexType>& hypIndices,
                              const Words& words,
                              const std::vector<IndexType>& batchIndices,
                              int beamSize) override {
  graph->switchParams(getName());
  auto wrapperState = std::dynamic_pointer_cast<ScorerWrapperState>(state);
  auto newState = encdec_->step(graph, wrapperState->getState(), hypIndices, words, batchIndices, beamSize);
  return New<ScorerWrapperState>(newState);
}
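For the flag to reach the model, the scorer wrapper and the beam-search loop would both need to forward it. The sketch below shows only that plumbing, under the same assumption of a hypothetical skipCost parameter. StubModel, stepAll and the reduced signatures are illustrative names and not part of Marian; the graph and index arguments are left out.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in: records whether the cost step was applied.
struct DecoderState {
  bool costApplied = false;
};

struct IEncoderDecoder {
  virtual ~IEncoderDecoder() = default;
  virtual std::shared_ptr<DecoderState> step(std::shared_ptr<DecoderState> state,
                                             bool skipCost) = 0;
};

// Stub model: applies the "cost" unless told to skip it.
struct StubModel : IEncoderDecoder {
  std::shared_ptr<DecoderState> step(std::shared_ptr<DecoderState>, bool skipCost) override {
    auto next = std::make_shared<DecoderState>();
    next->costApplied = !skipCost;
    return next;
  }
};

struct ScorerState {
  virtual ~ScorerState() = default;
};

class ScorerWrapperState : public ScorerState {
public:
  explicit ScorerWrapperState(std::shared_ptr<DecoderState> state) : state_(std::move(state)) {}
  std::shared_ptr<DecoderState> getState() const { return state_; }

private:
  std::shared_ptr<DecoderState> state_;
};

// Scorer wrapper that passes the run-time flag through to the model.
class ScorerWrapper {
public:
  explicit ScorerWrapper(std::shared_ptr<IEncoderDecoder> encdec) : encdec_(std::move(encdec)) {}

  std::shared_ptr<ScorerState> step(std::shared_ptr<ScorerState> state, bool skipCost) {
    auto wrapperState = std::dynamic_pointer_cast<ScorerWrapperState>(state);
    auto newState = encdec_->step(wrapperState->getState(), skipCost);
    return std::make_shared<ScorerWrapperState>(newState);
  }

private:
  std::shared_ptr<IEncoderDecoder> encdec_;
};

// One beam-search step over all scorers; the caller picks the flag per call.
void stepAll(std::vector<ScorerWrapper>& scorers,
             std::vector<std::shared_ptr<ScorerState>>& states,
             bool skipCost) {
  for(std::size_t i = 0; i < scorers.size(); ++i)
    states[i] = scorers[i].step(states[i], skipCost);
}
```

With this shape the decision moves from model construction to each call of stepAll, so a single loaded model could serve requests with and without the logprob computation.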

Metadata

Labels: enhancement
