Logging token level losses at inference time #4169
Conversation
thanks for the well-documented PR description @c-flaherty and nice job so far!! this is going to be a super useful feature.
i didn't finish reviewing but found some things that need fixing so i wanted to unblock you before EOD!
great! this is looking so much simpler. added a few more questions/comments!
looks nearly good to go, but need one last fix in beam search.
i'll also add one last request that you add a test that would capture some of the bugs you ran into here. For example, if you initialize the BeamSearch object and pass a set of (fake) logprobs to select_paths, can you ensure that it returns you the correct token scores? (& similarly for the other tree search objects)
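Something along these lines, perhaps (a rough sketch, not the actual test added in this PR; the BeamSearch constructor arguments, the select_paths signature, and the name of the returned token-score field are assumptions to be checked against torch_generator_agent.py):

```python
# Rough sketch only -- assumed API details are flagged in the comments.
import torch

from parlai.core.torch_generator_agent import BeamSearch


def test_select_paths_token_scores():
    beam_size, vocab_size = 2, 5
    # Assumption: BeamSearch can be built from just a beam size and special-token ids.
    beam = BeamSearch(beam_size, padding_token=0, bos_token=1, eos_token=2)

    # Fake log-probabilities with an unambiguous best token per beam:
    # token 3 for beam 0 (logprob -0.1) and token 4 for beam 1 (logprob -0.2).
    logprobs = torch.full((beam_size, vocab_size), -10.0)
    logprobs[0, 3] = -0.1
    logprobs[1, 4] = -0.2
    prior_scores = torch.zeros(beam_size)

    # Assumption: select_paths takes (logprobs, prior_scores, current_length).
    selection = beam.select_paths(logprobs, prior_scores, 1)

    # The per-token scores returned for the selected paths should match the
    # fake logprobs we injected. `token_scores` is a placeholder field name.
    returned = sorted(selection.token_scores.view(-1).tolist())
    assert abs(returned[0] - (-0.2)) < 1e-5
    assert abs(returned[1] - (-0.1)) < 1e-5
```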
I couldn't get the code to go any faster after playing around with some suggestions from Stephen, so instead I moved all token level logging related code behind guard statements. This means that none of the tensor operations related to token level logging run unless verbose mode is on. As expected, running without verbose mode is now no slower than it was before this change.
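A minimal illustration of that guard-statement pattern (the function and variable names here are mine, not the PR's actual code):

```python
import torch


def maybe_token_details(logprobs, chosen_token_ids, verbose):
    """Return (logprob, rank) for each chosen token, or None when verbose is off.

    Sketch of the idea only: every token-level tensor op sits behind the
    `verbose` guard, so the default decoding path does no extra work.
    """
    if not verbose:
        return None
    # Log-probability of each chosen token: gather along the vocab dimension.
    token_logprobs = logprobs.gather(-1, chosen_token_ids.unsqueeze(-1)).squeeze(-1)
    # Rank of each chosen token in the distribution (1 = most likely token).
    token_ranks = (logprobs > token_logprobs.unsqueeze(-1)).long().sum(dim=-1) + 1
    return token_logprobs, token_ranks
```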
Thanks for persisting through several rounds of edits @c-flaherty!!! This will be a super helpful change 😄
This looks good to go from my perspective, pending tests passing.
@stephenroller -- did you want to take a look before merge?
just to confirm -- do the rag tests you skip pass locally?
tests/nightly/gpu/test_bb2.py (outdated)
@@ -88,8 +88,9 @@ def test_retrieval_none(self):
    _test_bb2_rag(KnowledgeAccessMethod.NONE, n_docs=1)

@testing_utils.skipUnlessGPU
@unittest.skipIf(LOCAL, "Skipping Test because its slow and mem intensive")
# @testing_utils.skipUnlessGPU
you may have left these comments in by mistake
I'm just testing right now to see if skipping these tests fixes the cache issue. If it does, then yeah, I'll remove this comment, but if it doesn't, then I'll revert these changes related to skipping tests
(And address the cache issue in a separate pr)
tests/nightly/gpu/test_rag.py (outdated)
@@ -110,7 +110,7 @@
    }

@testing_utils.skipUnlessGPU
@unittest.skip("Cache too large")
Do we actually need to skip these all of the time? Or just when they run on Circle CI? If so, you can use the decorator skipIfCircleCI (in utils/testing.py).
Ahh, yeah I'll make that change if this indeed does fix the cache issue. I'll probably have to remove all these changes to the test decorators anyways though, as they probably won't fix the cache issues. Will keep this in mind!
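For reference, the decorator swap suggested above might look roughly like this (the test class and method names are hypothetical; skipIfCircleCI and skipUnlessGPU are the decorators from utils/testing.py mentioned in the comment above):

```python
# Hypothetical example of the suggested decorator swap -- class/method names are
# made up; the decorators come from parlai/utils/testing.py as noted above.
import unittest

import parlai.utils.testing as testing_utils


class TestRagExample(unittest.TestCase):
    @testing_utils.skipIfCircleCI  # skip only on CircleCI, not everywhere
    @testing_utils.skipUnlessGPU
    def test_rag_token(self):
        ...
```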
Intro
In this PR, I update TorchGeneratorAgent to support logging token-level conditional probabilities and ranks at inference time when in verbose mode. This logging works with greedy, beam, or any other kind of token generation currently supported by the agent. Here are two examples:
parlai dm --model-file zoo:unittest/transformer_generator2/model --truncate 1024 -v --task integration_tests:multiturn_nocandidate -ne 1 --inference beam --beam-size 3
parlai dm --model-file zoo:unittest/transformer_generator2/model --truncate 1024 -v --task integration_tests:multiturn_nocandidate -ne 1 --inference greedy
A brief explanation with code pointers:
The scores are here: ParlAI/parlai/core/torch_generator_agent.py, line 1135 (commit 8094996).
To my understanding, score is a tensor of shape (batch size, num beams, vocab size), and score[b, i, :], for example, contains the conditional probabilities of each token in the vocab being the next token in the i-th beam of batch b. However, generation candidates are added to beam objects whenever an EOS token is found, so accumulating score does not really get us all the way there. We need each beam to accumulate its own token probabilities and store them whenever finished hypotheses are found. This requires
(1) adding an additional attribute to the TreeSearch object to store them: ParlAI/parlai/core/torch_generator_agent.py, lines 1276 to 1277 (commit 8094996)
(2) updating the TreeSearch.select_paths method to output the token probabilities of the next paths in the beam: ParlAI/parlai/core/torch_generator_agent.py, line 1327 (commit 8094996)
(3) updating the stored beam token probabilities each time the TreeSearch.advance method is called: ParlAI/parlai/core/torch_generator_agent.py, lines 1423 to 1436 (commit 8094996)
and finally (4) storing token probabilities along with candidate utterances in the TreeSearch.finished attribute: ParlAI/parlai/core/torch_generator_agent.py, lines 1438 to 1450 (commit 8094996)
Once we do this, we can easily pass the token probabilities through get_rescored_finished (the method that applies the length penalty to the utterance-level probability) while keeping them stored alongside their candidates, output them from TorchGeneratorAgent._generate in both beam_preds_scores and beams for free, and finally assign them to the token_losses variable in TorchGeneratorAgent.eval_step, so they are output the same way they are for examples with labels.
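To make the bookkeeping in steps (1) through (4) concrete, here is a toy, self-contained sketch (not ParlAI code; toy_beam_step and its arguments are made up) of carrying per-token log-probabilities alongside each hypothesis and storing them when a hypothesis hits EOS, which is essentially what TreeSearch.finished holds after this change:

```python
import torch

EOS = 2  # toy EOS id for this sketch


def toy_beam_step(partial_hyps, partial_token_scores, logprobs, beam_size):
    """One toy decoding step that carries per-token scores alongside hypotheses.

    partial_hyps:         list of token-id lists, one per live beam
    partial_token_scores: list of per-token logprob lists, one per live beam
    logprobs:             (num_beams, vocab) log-probabilities for the next token
    """
    vocab = logprobs.size(1)
    # Utterance-level score so far = sum of the per-token scores carried along.
    acc = torch.tensor([sum(s) for s in partial_token_scores]).unsqueeze(1)
    flat = (acc + logprobs).view(-1)
    best = flat.topk(beam_size).indices
    beam_ids = torch.div(best, vocab, rounding_mode="floor")
    token_ids = best % vocab

    new_hyps, new_scores, finished = [], [], []
    for b, t in zip(beam_ids.tolist(), token_ids.tolist()):
        hyp = partial_hyps[b] + [t]
        scores = partial_token_scores[b] + [logprobs[b, t].item()]
        if t == EOS:
            # Keep the finished hypothesis together with its token-level scores,
            # analogous to what TreeSearch.finished stores after this PR.
            finished.append((hyp, scores))
        else:
            new_hyps.append(hyp)
            new_scores.append(scores)
    return new_hyps, new_scores, finished
```

In the real agent, get_rescored_finished then re-ranks the finished hypotheses with the length penalty, and the token-level scores simply travel alongside each candidate through that re-ranking.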
Tests:
pytest tests/test_tga.py