[s2s] distill t5-large -> t5-small #8376
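In brief, this PR lets the seq2seq distillation script initialize the student from any pretrained seq2seq checkpoint (here t5-small, distilled from a t5-large teacher) instead of only building it by copying alternating teacher layers. A minimal sketch of that initialization path using the public transformers API (an illustration of the idea, not the PR's exact trainer code):

# Hedged sketch: the teacher/student pairing this PR enables.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

teacher = AutoModelForSeq2SeqLM.from_pretrained("t5-large").eval()  # kept frozen during distillation
student = AutoModelForSeq2SeqLM.from_pretrained("t5-small")         # trained to match the teacher

# The teacher's tokenizer serves both models, so the same labels can be
# turned into decoder_input_ids for either model in the training step.
tokenizer = AutoTokenizer.from_pretrained("t5-large")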

Merged 25 commits on Nov 11, 2020
Changes from 1 commit

Commits (25)

1ff83eb  Implement student model using different base model  (sbhaktha, Oct 13, 2020)
758883a  changes  (sbhaktha, Oct 13, 2020)
0ccf184  Merge branch 'master' of https://github.com/huggingface/transformers …  (sbhaktha, Oct 13, 2020)
ddf4d5d  Merging changes from master  (sbhaktha, Oct 14, 2020)
ee878fe  Use teacher encoder outputs while calling teacher decoder  (sbhaktha, Oct 14, 2020)
58a9a6e  Merge branch 'master' of https://github.com/huggingface/transformers …  (sbhaktha, Oct 16, 2020)
cfbc357  Return 0 tensor when hidden loss is not applicable, rename student_ba…  (sbhaktha, Oct 16, 2020)
73c0a54  Merge branch 'master' of https://github.com/huggingface/transformers …  (sbhaktha, Nov 7, 2020)
8bc75fd  Do not create student model in eval mode  (sbhaktha, Nov 7, 2020)
708e457  Remove debug print  (sbhaktha, Nov 7, 2020)
7bc3baa  Fix bugs causing failed tests  (sbhaktha, Nov 9, 2020)
02fa544  Adding unit tests. Including code refactor per request on PR.  (sbhaktha, Nov 9, 2020)
fb7a6f0  Merge branch 'master' of https://github.com/huggingface/transformers …  (sbhaktha, Nov 9, 2020)
9432e67  Formatting changes per make fixup  (sbhaktha, Nov 9, 2020)
65945d7  style  (sshleifer, Nov 9, 2020)
937082e  Selectively unpack teacher encoder output and hidden states  (sbhaktha, Nov 10, 2020)
9f3bc7c  Merge branch 'master' of https://github.com/huggingface/transformers …  (sbhaktha, Nov 10, 2020)
65fc338  Merge branch 'add_student_base_model' of https://github.com/sbhaktha/…  (sbhaktha, Nov 10, 2020)
d42ccbe  Merge branch 'master' into add_student_base_model  (sshleifer, Nov 11, 2020)
2157d4d  style  (sshleifer, Nov 11, 2020)
fcb2de3  fixup  (sshleifer, Nov 11, 2020)
f7495de  fixed return_dict issue  (sshleifer, Nov 11, 2020)
f313883  Merge branch 'master' into add_student_base_model  (sshleifer, Nov 11, 2020)
26b3b91  style  (sshleifer, Nov 11, 2020)
9858290  use_cache=False  (sshleifer, Nov 11, 2020)
Return 0 tensor when hidden loss is not applicable, rename student_base_model cmd line argument to student, get rid of print statements.
sbhaktha committed Oct 16, 2020
commit cfbc357c115f1f2b4d9e663873beb3425dd8457a
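Why a zero tensor rather than the Python float 0.0: the other loss components in the training step are torch.Tensors, so keeping the no-op hidden loss a tensor keeps the types uniform for the caller. A simplified stand-in for the helper changed below (the function name here is hypothetical; the real change is in maybe_calc_hidden_loss in the diff):

import torch

def maybe_calc_hidden_loss_sketch(matches, hidden_loss):
    # When there is no layer matching to supervise, return a zero *tensor*
    # rather than 0.0, so callers can treat every loss term as a torch.Tensor.
    if matches:
        return hidden_loss
    return torch.tensor(0.0)

alpha_hid = 0.3
total_loss = torch.tensor(1.5) + alpha_hid * maybe_calc_hidden_loss_sketch(None, None)
print(type(total_loss).__name__, total_loss)  # Tensor tensor(1.5000)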
examples/seq2seq/distillation.py (29 changes: 7 additions & 22 deletions)
@@ -42,8 +42,8 @@ def __init__(self, hparams):
         use_task_specific_params(teacher, hparams.task)  # We copy good generation parameters to student by default
 
         e_layer_ids, d_layer_ids = None, None
-        if hparams.student_base_model is not None:
-            student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student_base_model).eval()
+        if hparams.student is not None:
+            student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student).eval()
             use_task_specific_params(student, hparams.task)
         else:
             student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
@@ -90,7 +90,7 @@ def __init__(self, hparams):
         self.e_matches = None
         self.d_matches = None
 
-        if hparams.student_base_model is None or hparams.teacher == hparams.student_base_model:
+        if hparams.student is None or hparams.teacher == hparams.student:
             # Intermediate supervision: Decide which layers to supervise
             if hparams.supervise_forward:
                 self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
@@ -149,16 +149,13 @@ def add_model_specific_args(parser, root_dir):
     def _step(self, batch):
         # assert is_frozen(self.teacher) copied_decoder_layers
         pad_token_id = self.tokenizer.pad_token_id
-        print(f'pad_token_id: {pad_token_id}, tokenizer: {type(self.tokenizer)}')
         input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
         if isinstance(self.teacher, T5ForConditionalGeneration):
-            #print('Teacher model: T5ForConditionalGeneration')
             teacher_decoder_input_ids = self.teacher._shift_right(labels)
         else:
             teacher_decoder_input_ids = shift_tokens_right(labels, pad_token_id)
 
         if isinstance(self.model, T5ForConditionalGeneration):
-            #print('Student model: T5ForConditionalGeneration')
             student_decoder_input_ids = self.model._shift_right(labels)
         else:
             student_decoder_input_ids = shift_tokens_right(labels, pad_token_id)
@@ -188,8 +185,7 @@ def _step(self, batch):
         def zero_tensor():
             return torch.tensor(0.0).type_as(student_lm_loss)
 
-        teacher_enc_outputs =(enc_outputs,)
-
+        teacher_enc_outputs = enc_outputs
         hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
         if self.different_encoder:  # compute encoder hidden state loss
             with torch.no_grad():
@@ -210,16 +206,13 @@ def zero_tensor():
             outputs = self.teacher(
                 input_ids,
                 attention_mask=teacher_mask,
-                encoder_outputs=teacher_enc_outputs,
+                encoder_outputs=(teacher_enc_outputs, ),
                 decoder_input_ids=teacher_decoder_input_ids,
                 lm_labels=labels,
                 output_hidden_states=True,
                 return_dict=True,
             )
-            #print(f'outputs len: {len(outputs)}')
             tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
-            #print(tlogits)
-            #print(tdec_hidden)
         dec_mask = teacher_decoder_input_ids.ne(pad_token_id)
         loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
         if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
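For context on the encoder_outputs line in the hunk above: the teacher's encoder is run once and its output is passed back into the teacher's forward call, so the decoder pass does not re-encode the source. A self-contained, hedged sketch of the same pattern against the public T5 API (a standalone example, not the trainer code; t5-small stands in for the teacher to keep it small):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
teacher = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

batch = tok(["translate English to German: Hello"], return_tensors="pt")
labels = tok(["Hallo"], return_tensors="pt").input_ids

with torch.no_grad():
    # Encode once ...
    enc = teacher.get_encoder()(batch.input_ids, attention_mask=batch.attention_mask, return_dict=True)
    # ... then reuse the cached encoder output for the decoder pass.
    out = teacher(
        attention_mask=batch.attention_mask,
        encoder_outputs=enc,
        decoder_input_ids=teacher._shift_right(labels),
        output_hidden_states=True,
        return_dict=True,
        use_cache=False,
    )
print(out.logits.shape, len(out.decoder_hidden_states))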
@@ -258,15 +251,14 @@ def maybe_calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, match
     if matches:
         return calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden)
     else:
-        print('No matches, returning 0 for hidden loss')
-        return 0.0
+        return torch.tensor(0.0)
 
 def add_distill_args(parser):
     parser.add_argument("--teacher", type=str)
     parser.add_argument("--alpha_ce", default=0.8, type=float)
     parser.add_argument("--alpha_mlm", default=0.2, type=float)
     parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
-    parser.add_argument("--student_base_model", type=str, required=False)
+    parser.add_argument("--student", type=str, required=False)
     parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
     parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
     parser.add_argument("--no_teacher", action="store_true", default=False)
@@ -319,13 +311,6 @@ def distill_main(args):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
 
     model = create_module(args)
-
-    #if args.student_base_model:
-    #    print('Calling forward on teacher and student')
-    #    teacher = ft_main(args, model=model.teacher)
-    #    student = ft_main(args, model=model.model)
-    #    print('Called forward on teacher and student')
-    #else:
     return ft_main(args, model=model)


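For completeness, a hedged sketch of how the renamed command-line flag is consumed; the argument definitions are copied from the add_distill_args hunk above (only a subset, and the wrapper name here is hypothetical):

import argparse

def add_distill_args_subset(parser):
    # Subset of the options shown in the diff above.
    parser.add_argument("--teacher", type=str)
    parser.add_argument("--student", type=str, required=False)  # formerly --student_base_model
    parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
    parser.add_argument("--no_teacher", action="store_true", default=False)

parser = argparse.ArgumentParser()
add_distill_args_subset(parser)
args = parser.parse_args(["--teacher", "t5-large", "--student", "t5-small"])
print(args.teacher, args.student, args.alpha_hid)  # t5-large t5-small 0.0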