
[s2s] distill t5-large -> t5-small #8376

Merged: 25 commits, merged Nov 11, 2020

Changes from 15 commits

Commits (25)
1ff83eb
Implement student model using different base model
sbhaktha Oct 13, 2020
758883a
changes
sbhaktha Oct 13, 2020
0ccf184
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Oct 13, 2020
ddf4d5d
Merging changes from master
sbhaktha Oct 14, 2020
ee878fe
Use teacher encoder outputs while calling teacher decoder
sbhaktha Oct 14, 2020
58a9a6e
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Oct 16, 2020
cfbc357
Return 0 tensor when hidden loss is not applicable, rename student_ba…
sbhaktha Oct 16, 2020
73c0a54
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Nov 7, 2020
8bc75fd
Do not create student model in eval mode
sbhaktha Nov 7, 2020
708e457
Remove debug print
sbhaktha Nov 7, 2020
7bc3baa
Fix bugs causing failed tests
sbhaktha Nov 9, 2020
02fa544
Adding unit tests. Including code refactor per request on PR.
sbhaktha Nov 9, 2020
fb7a6f0
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Nov 9, 2020
9432e67
Formatting changes per make fixup
sbhaktha Nov 9, 2020
65945d7
style
sshleifer Nov 9, 2020
937082e
Selectively unpack teacher encoder output and hidden states
sbhaktha Nov 10, 2020
9f3bc7c
Merge branch 'master' of https://github.com/huggingface/transformers …
sbhaktha Nov 10, 2020
65fc338
Merge branch 'add_student_base_model' of https://github.com/sbhaktha/…
sbhaktha Nov 10, 2020
d42ccbe
Merge branch 'master' into add_student_base_model
sshleifer Nov 11, 2020
2157d4d
style
sshleifer Nov 11, 2020
fcb2de3
fixup
sshleifer Nov 11, 2020
f7495de
fixed return_dict issue
sshleifer Nov 11, 2020
f313883
Merge branch 'master' into add_student_base_model
sshleifer Nov 11, 2020
26b3b91
style
sshleifer Nov 11, 2020
9858290
use_cache=False
sshleifer Nov 11, 2020
101 changes: 71 additions & 30 deletions examples/seq2seq/distillation.py
@@ -40,24 +40,45 @@ def __init__(self, hparams):
hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
)

e_layer_ids, d_layer_ids = None, None
if hparams.student is not None:
student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student)
use_task_specific_params(student, hparams.task)
else:
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
)

if hparams.length_penalty != -1:
student.config.length_penalty = hparams.length_penalty
super().__init__(hparams, model=student, config=student.config)
model_type = student.config.model_type
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
student_model_type = student.config.model_type
teacher_model_type = teacher.config.model_type

student_encoder_layers, student_decoder_layers = None, None

if student_model_type == "t5":
student_encoder_layers = len(student.get_encoder().block)
student_decoder_layers = len(student.get_decoder().block)
else:
student_encoder_layers = student.config.encoder_layers
student_decoder_layers = student.config.decoder_layers

if model_type == "t5":
if teacher_model_type == "t5":
teacher_encoder_layers = len(teacher.get_encoder().block)
teacher_decoder_layers = len(teacher.get_decoder().block)
else:
teacher_encoder_layers = teacher.config.encoder_layers
teacher_decoder_layers = teacher.config.decoder_layers

self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
self.different_encoder = student_encoder_layers != teacher_encoder_layers

if e_layer_ids is None or d_layer_ids is None:
e_layer_ids = list(range(student_encoder_layers))
d_layer_ids = list(range(student_decoder_layers))

self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]

self.teacher = teacher
freeze_params(self.teacher)
@@ -67,13 +88,24 @@ def __init__(self, hparams):
del self.teacher.model.encoder
except AttributeError: # T5
del self.teacher.encoder
# Intermediate supervision: Decide which layers to supervise
if hparams.supervise_forward:
self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
else: # student layer should emulate hidden states of the teacher layer it was copied from
self.e_matches = self.e_layer_ids
self.d_matches = self.d_layer_ids

self.e_matches = None
self.d_matches = None
self.do_calc_hidden_loss = False

if hparams.student is None or hparams.teacher == hparams.student:
# Intermediate supervision: Decide which layers to supervise
if hparams.supervise_forward:
self.e_matches = get_layers_to_supervise(
n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers
)
self.d_matches = get_layers_to_supervise(
n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers
)
else: # student layer should emulate hidden states of the teacher layer it was copied from
self.e_matches = self.e_layer_ids
self.d_matches = self.d_layer_ids
self.do_calc_hidden_loss = True

self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.temperature = 2.0
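
The ce_loss_fct / temperature pair set up just above feeds a temperature-scaled KL distillation objective. For reference, a minimal sketch of that standard objective (an illustration only, not the file's calc_ce_loss, which also masks padded decoder positions):

    import torch
    import torch.nn.functional as F
    from torch import nn

    def kl_distill_loss(
        student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float = 2.0
    ) -> torch.Tensor:
        # Soften both distributions with the temperature, take batchmean KL, and
        # rescale by T^2 so gradient magnitudes stay comparable across temperatures.
        kl = nn.KLDivLoss(reduction="batchmean")
        loss = kl(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
        )
        return loss * temperature**2
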
@@ -126,6 +158,7 @@ def _step(self, batch):
# assert is_frozen(self.teacher) copied_decoder_layers
pad_token_id = self.tokenizer.pad_token_id
input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]

if isinstance(self.model, T5ForConditionalGeneration):
decoder_input_ids = self.model._shift_right(labels)
else:
@@ -156,26 +189,28 @@
def zero_tensor():
return torch.tensor(0.0).type_as(student_lm_loss)

teacher_enc_outputs = enc_outputs
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
if self.different_encoder: # compute encoder hidden state loss
with torch.no_grad():
teacher_enc_hid = self.teacher.get_encoder()(
input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
).hidden_states

hid_loss_enc = self.calc_hidden_loss(
src_mask,
enc_hidden_state,
teacher_enc_hid,
self.e_matches,
normalize_hidden=self.hparams.normalize_hidden,
)

teacher_enc_outputs, teacher_enc_hid = self.teacher.get_encoder()(
Contributor:
I think in the old case of distilling to a copied student, this change uses more memory.
Trying to find a clean solution locally.

Contributor Author:
I see. Not sure what part of the change makes it use more memory? Is there any way I can help?

Contributor:
We used to extract only hidden states; now we extract both hidden states and encoder outputs.
Your application doesn't need the hidden states, and the other application doesn't need the encoder outputs, so splitting the call to self.teacher.get_encoder() so that each case only gets what it needs is the next step.

Contributor Author:
Got it. I am attempting a fix.

Contributor Author:
It is not true that the other application doesn't need encoder outputs. When calling the teacher decoder, if student and teacher encoders are different, the right thing to do is to pass the teacher encoder outputs, and as we can see from these lines, even in the original case, we could have self.different_encoder be true if not all layers were copied from teacher to student:

self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
So I think we do need to extract teacher encoder outputs inside this if different_encoder block.
However, we can extract hidden states optionally based on whether we need to calculate hidden loss or not. I have the following modified code which passes the unit tests. Please lmk what you think.

        if self.different_encoder:  # compute encoder hidden state loss
            with torch.no_grad():
                teacher_encoder = self.teacher.get_encoder()(
                    input_ids,
                    attention_mask=src_mask,
                    output_hidden_states=self.do_calc_hidden_loss,
                    return_dict=True
                )
            teacher_enc_outputs = teacher_encoder.last_hidden_state
            if self.do_calc_hidden_loss:
                teacher_enc_hid = teacher_encoder.hidden_states
                hid_loss_enc = self.calc_hidden_loss(
                    src_mask,
                    enc_hidden_state,
                    teacher_enc_hid,
                    self.e_matches,
                    normalize_hidden=self.hparams.normalize_hidden,
                )

That said, the same boolean self.do_calc_hidden_loss can be used to set output_hidden_states to True or False in the call to the teacher decoder as well. If the above makes sense to you, I'll make a similar change here as well and push a commit.

Contributor Author:
ah, thanks for the clarification! I'll take another look.

Contributor Author:
OK, I have another fix that's passing unit tests. Please lmk what you think. I renamed self.do_calc_hidden_loss to self.different_base_models, because that's really what it was keeping track of, and its boolean value is flipped relative to the earlier one.

        if self.different_encoder:  # compute encoder hidden state loss
            with torch.no_grad():
                teacher_encoder = self.teacher.get_encoder()(
                    input_ids,
                    attention_mask=src_mask,
                    output_hidden_states=not self.different_base_models,
                    return_dict=True
                )
            if self.different_base_models:
                teacher_enc_outputs = teacher_encoder.last_hidden_state
            else:
                teacher_enc_hid = teacher_encoder.hidden_states
                hid_loss_enc = self.calc_hidden_loss(
                    src_mask,
                    enc_hidden_state,
                    teacher_enc_hid,
                    self.e_matches,
                    normalize_hidden=self.hparams.normalize_hidden,
                )

        teacher_mask = input_ids.ne(pad_token_id)
        with torch.no_grad():
            outputs = self.teacher(
                input_ids,
                attention_mask=teacher_mask,
                encoder_outputs=(teacher_enc_outputs, ),
                decoder_input_ids=decoder_input_ids,
                lm_labels=labels,
                output_hidden_states=not self.different_base_models,
                return_dict=True,
            )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, outputs.logits)
        if (not self.different_base_models) and self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
            tdec_hidden = outputs.decoder_hidden_states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
            )

self.different_base_models is initialized in __init__ in the same place as the earlier self.do_calc_hidden_loss as follows:

        self.different_base_models = True

        if hparams.student is None or hparams.teacher == hparams.student:
            self.different_base_models = False
            # Intermediate supervision: Decide which layers to supervise
            if hparams.supervise_forward:
                self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
                self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
            else:  # student layer should emulate hidden states of the teacher layer it was copied from
                self.e_matches = self.e_layer_ids
                self.d_matches = self.d_layer_ids

Contributor:
yes push that, I'll make some changes and get it merged.

Contributor:
Thanks for iterating :)

Contributor Author:
pushed!

input_ids, attention_mask=src_mask, output_hidden_states=True
)
if self.do_calc_hidden_loss:
hid_loss_enc = self.calc_hidden_loss(
src_mask,
enc_hidden_state,
teacher_enc_hid,
self.e_matches,
normalize_hidden=self.hparams.normalize_hidden,
)

teacher_mask = input_ids.ne(pad_token_id)
with torch.no_grad():
outputs = self.teacher(
input_ids,
attention_mask=src_mask,
encoder_outputs=(enc_outputs,),
attention_mask=teacher_mask,
encoder_outputs=(teacher_enc_outputs,),
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
output_hidden_states=True,
@@ -184,7 +219,7 @@ def zero_tensor():
tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
if self.do_calc_hidden_loss and self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
hid_loss_dec = self.calc_hidden_loss(
dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
)
@@ -215,10 +250,16 @@ def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, no


def add_distill_args(parser):
# NOTE: if the --student argument is specified and the teacher and student base models
# are different, the two models still have to share the same tokenizer, specified by
# --tokenizer_name. For example, you can distill from t5-large to t5-small but not
# from bart to t5. This is because if the tokenizers are different, the output spaces
# of the two models are also different and their logits are not comparable.
parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument("--student", type=str, required=False)
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
parser.add_argument("--no_teacher", action="store_true", default=False)
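To make the tokenizer constraint in the NOTE above concrete, here is a minimal sanity check one could run before picking a teacher/student pair (a sketch, not part of this diff; the checkpoint names are just illustrative):

    from transformers import AutoTokenizer

    # Teacher and student must share a tokenizer: with different vocabularies the
    # two models' logits live in different output spaces and cannot be compared.
    teacher_tok = AutoTokenizer.from_pretrained("t5-large")
    student_tok = AutoTokenizer.from_pretrained("t5-small")
    assert teacher_tok.get_vocab() == student_tok.get_vocab(), (
        "Teacher and student tokenizers differ; choose checkpoints with the same vocabulary."
    )
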
11 changes: 11 additions & 0 deletions examples/seq2seq/test_seq2seq_examples.py
@@ -89,6 +89,7 @@
"freeze_encoder": False,
"auto_scale_batch_size": False,
"overwrite_output_dir": False,
"student": None,
}


@@ -100,6 +101,7 @@ def _dump_articles(path: Path, articles: list):
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
T5_TINIER = "sshleifer/t5-tinier-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
@@ -226,6 +228,15 @@ def test_distill_t5(self):
)
self._test_distiller_cli(updates)

def test_distill_different_student_teacher_base_models(self):
updates = dict(
teacher=T5_TINY,
student=T5_TINIER,
model_name_or_path=T5_TINIER,
tokenizer_name=T5_TINIER,
)
self._test_distiller_cli(updates)

def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
label_smoothing=0.0,
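The new test above uses tiny random checkpoints; a real run targeting the pairing in this PR's title would pass analogous, tokenizer-compatible checkpoints (a hypothetical sketch of the equivalent updates, not code from this diff):

    # Hypothetical real-scale analogue of test_distill_different_student_teacher_base_models:
    # t5-large and t5-small share the T5 SentencePiece tokenizer, as the NOTE in
    # distillation.py requires.
    updates = dict(
        teacher="t5-large",
        student="t5-small",
        model_name_or_path="t5-small",
        tokenizer_name="t5-small",
    )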