[s2s] distill t5-large -> t5-small #8376

Merged
merged 25 commits into from
Nov 11, 2020
1ff83eb
Implement student model using different base model
sbhaktha Oct 13, 2020
758883a
changes
sbhaktha Oct 13, 2020
0ccf184
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Oct 13, 2020
ddf4d5d
Merging changes from master
sbhaktha Oct 14, 2020
ee878fe
Use teacher encoder outputs while calling teacher decoder
sbhaktha Oct 14, 2020
58a9a6e
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Oct 16, 2020
cfbc357
Return 0 tensor when hidden loss is not applicable, rename student_ba…
sbhaktha Oct 16, 2020
73c0a54
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Nov 7, 2020
8bc75fd
Do not create student model in eval mode
sbhaktha Nov 7, 2020
708e457
Remove debug print
sbhaktha Nov 7, 2020
7bc3baa
Fix bugs causing failed tests
sbhaktha Nov 9, 2020
02fa544
Adding unit tests. Including code refactor per request on PR.
sbhaktha Nov 9, 2020
fb7a6f0
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Nov 9, 2020
9432e67
Formatting changes per make fixup
sbhaktha Nov 9, 2020
65945d7
style
sshleifer Nov 9, 2020
937082e
Selectively unpack teacher encoder output and hidden states
sbhaktha Nov 10, 2020
9f3bc7c
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Nov 10, 2020
65fc338
Merge branch 'add_student_base_model' of https://github.com/sbhaktha/…
sbhaktha Nov 10, 2020
d42ccbe
Merge branch 'master' into add_student_base_model
sshleifer Nov 11, 2020
2157d4d
style
sshleifer Nov 11, 2020
fcb2de3
fixup
sshleifer Nov 11, 2020
f7495de
fixed return_dict issue
sshleifer Nov 11, 2020
f313883
Merge branch 'master' into add_student_base_model
sshleifer Nov 11, 2020
26b3b91
style
sshleifer Nov 11, 2020
9858290
use_cache=False
sshleifer Nov 11, 2020
style
sshleifer committed Nov 11, 2020
commit 2157d4d1dfc716b4ddb0a30a55a2123f13d0e36a
6 changes: 4 additions & 2 deletions examples/seq2seq/distillation.py
@@ -198,7 +198,7 @@ def zero_tensor():
     input_ids,
     attention_mask=src_mask,
     output_hidden_states=not self.different_base_models,
-    return_dict=True
+    return_dict=True,
 )
 if self.different_base_models:
     teacher_enc_outputs = teacher_encoder.last_hidden_state
@@ -225,7 +225,9 @@ def zero_tensor():
 )
 dec_mask = decoder_input_ids.ne(pad_token_id)
 loss_ce = self.calc_ce_loss(dec_mask, lm_logits, outputs.logits)
-if (not self.different_base_models) and self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
+if (
+    not self.different_base_models
+) and self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
     tdec_hidden = outputs.decoder_hidden_states
     hid_loss_dec = self.calc_hidden_loss(
         dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
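
The guard in the second hunk skips intermediate hidden-state supervision when teacher and student come from different base models or when the hidden-loss weight is zero. A minimal sketch of that control flow follows, assuming plain Python floats in place of tensors. The names `masked_mse` and `hidden_loss_or_zero` are hypothetical and are not the PR's `calc_hidden_loss`. Returning zero in the skipped case follows the commit message "Return 0 tensor when hidden loss is not applicable".

```python
from typing import Sequence


def masked_mse(student: Sequence[float], teacher: Sequence[float], mask: Sequence[bool]) -> float:
    """Mean squared error over the positions where mask is True (hypothetical helper)."""
    kept = [(s - t) ** 2 for s, t, m in zip(student, teacher, mask) if m]
    return sum(kept) / len(kept) if kept else 0.0


def hidden_loss_or_zero(
    student: Sequence[float],
    teacher: Sequence[float],
    mask: Sequence[bool],
    different_base_models: bool,
    alpha_hid: float,
) -> float:
    """Apply the same guard as the diff: no hidden loss across different base models."""
    # Negation of `(not different_base_models) and alpha_hid > 0`.
    if different_base_models or alpha_hid <= 0:
        return 0.0  # stands in for the zero tensor mentioned in the commit list
    return masked_mse(student, teacher, mask)


if __name__ == "__main__":
    # Same base model and a positive weight: the loss is computed.
    print(hidden_loss_or_zero([1.0, 2.0], [1.0, 4.0], [True, True], False, 1.0))
    # Different base models: the loss is skipped.
    print(hidden_loss_or_zero([1.0, 2.0], [1.0, 4.0], [True, True], True, 1.0))
```

One plausible reason for the guard is that t5-large and t5-small use different hidden sizes, so their hidden states cannot be compared elementwise without a projection.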