diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 0ddad3cbf985..73715b9c8af1 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1354,8 +1354,14 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             )
             num_steps = self.state.global_step - self._globalstep_last_logged
             seq_length = None
+            model_flops = None
             if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
                 seq_length = getattr(self.model.config, "seq_length", None)
+                try:
+                    model_flops = self.model.get_hardware_flops(seq_length=seq_length, recompute=self.args.recompute)
+                except NotImplementedError:
+                    model_flops = None
+
             logs.update(
                 speed_metrics(
                     "interval",
@@ -1363,6 +1369,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
                     num_samples=total_train_batch_size * num_steps,
                     num_steps=num_steps,
                     seq_length=seq_length,
+                    model_flops=model_flops,
                 )
             )
 
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index 86504648cc48..37ccd5b1f6c5 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -344,7 +344,7 @@ def total_processes_number(local_rank):
     return 1
 
 
-def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None):
+def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_length=None, model_flops=None):
     """
     Measure and return speed performance metrics.
 
@@ -365,6 +365,11 @@ def speed_metrics(split, start_time, num_samples=None, num_steps=None, seq_lengt
     if seq_length is not None:
         tokens_per_second_per_device = samples_per_second * seq_length / paddle.distributed.get_world_size()
         result[f"{split}_tokens_per_second_per_device"] = round(tokens_per_second_per_device, 4)
+        if model_flops is not None:
+            result[f"{split}_hardware_tflops_per_device"] = round(
+                tokens_per_second_per_device * model_flops / seq_length / 2**40, 2
+            )
+
     if num_steps is not None:
         steps_per_second = num_steps / runtime
         result[f"{split}_steps_per_second"] = round(steps_per_second, 4)
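Note (editorial, not part of the patch): a minimal sketch of the arithmetic behind the new `interval_hardware_tflops_per_device` entry, with assumed numbers; `model_flops` stands in for what `get_hardware_flops()` returns for a single sample.

# Illustration only; throughput, FLOPs and peak values below are assumptions.
seq_length = 4096
tokens_per_second_per_device = 3500.0  # assumed measured throughput
model_flops = 1.9e14                   # assumed per-sample hardware FLOPs (LLaMA-7B-like, seq_length=4096)

# Same formula as speed_metrics above; dividing by 2**40 reports binary "tera" FLOPs, not 1e12.
hardware_tflops = tokens_per_second_per_device * model_flops / seq_length / 2**40
print(round(hardware_tflops, 2))  # ~147.7

# To compare against a datasheet peak (quoted in 1e12 FLOPs), convert back first.
peak_flops = 312e12                    # assumed bf16 dense peak of the accelerator
print(f"MFU ~= {hardware_tflops * 2**40 / peak_flops:.1%}")  # ~52%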
diff --git a/paddlenlp/transformers/gemma/modeling.py b/paddlenlp/transformers/gemma/modeling.py
index 1be5a2453d5c..dfb93345f25d 100644
--- a/paddlenlp/transformers/gemma/modeling.py
+++ b/paddlenlp/transformers/gemma/modeling.py
@@ -55,6 +55,7 @@
 from .. import linear_utils
 from ..linear_utils import Linear
 from ..segment_parallel_utils import ReshardLayer
+from ..utils import caculate_llm_flops
 from .configuration import (
     GEMMA_PRETRAINED_INIT_CONFIGURATION,
     GEMMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1074,6 +1075,39 @@ def __init__(self, config: GemmaConfig):
 
         self.gradient_checkpointing = False
 
+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
diff --git a/paddlenlp/transformers/gpt/modeling.py b/paddlenlp/transformers/gpt/modeling.py
index a618600fc446..1a5bbd3ed698 100644
--- a/paddlenlp/transformers/gpt/modeling.py
+++ b/paddlenlp/transformers/gpt/modeling.py
@@ -53,6 +53,7 @@
     TokenClassifierOutput,
 )
 from ..model_utils import dy2st_nocheck_guard_context
+from ..utils import caculate_llm_flops
 from .configuration import (
     GPT_PRETRAINED_INIT_CONFIGURATION,
     GPT_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1105,6 +1106,39 @@ def __init__(self, config: GPTConfig):
             decoder_layers,
         )
 
+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
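A rough sanity check (editorial, not part of the patch) on how `get_hardware_flops` relates to `get_model_flops`: with `recompute_granularity="full"`, `caculate_llm_flops` counts one extra forward pass per layer on top of the 1x forward + 2x backward accounting, so the ratio tends to 4/3 once the logits term becomes negligible. A self-contained sketch with assumed LLaMA-7B-like per-layer numbers:

# Sketch only; flops_layer / flops_logits values are assumptions, and the common
# factor of 2 * batch_size in caculate_llm_flops cancels out of the ratio.
def hardware_over_model_ratio(flops_layer, flops_logits, layer_num):
    model = layer_num * flops_layer * 3 + 3 * flops_logits                     # fwd + bwd only
    hardware = layer_num * (flops_layer * 3 + flops_layer) + 3 * flops_logits  # plus one recomputed fwd per layer
    return hardware / model

print(hardware_over_model_ratio(flops_layer=9.7e11, flops_logits=5.4e11, layer_num=32))  # ~1.33, i.e. roughly 4/3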
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 027e91eac95e..d1415d2e9565 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -70,6 +70,7 @@ def swiglu(x, y=None):
 from .. import linear_utils
 from ..linear_utils import Linear
 from ..segment_parallel_utils import ReshardLayer
+from ..utils import caculate_llm_flops
 from .configuration import (
     LLAMA_PRETRAINED_INIT_CONFIGURATION,
     LLAMA_PRETRAINED_RESOURCE_FILES_MAP,
@@ -1468,6 +1469,39 @@ def __init__(self, config: LlamaConfig):
 
         self.gradient_checkpointing = False
 
+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py
index c15e1687c9f3..95d67ae788b2 100644
--- a/paddlenlp/transformers/model_utils.py
+++ b/paddlenlp/transformers/model_utils.py
@@ -1102,6 +1102,20 @@ def get_memory_footprint(self, return_buffers=True):
             mem = mem + mem_bufs
         return mem
 
+    def get_model_flops(self, *args, **kwargs):
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_model_flops(*args, **kwargs)
+
+        raise NotImplementedError(f"model of type {type(base_model)} has not implemented `get_model_flops`")
+
+    def get_hardware_flops(self, *args, **kwargs):
+        base_model = getattr(self, self.base_model_prefix, self)
+        if base_model is not self:
+            return base_model.get_hardware_flops(*args, **kwargs)
+
+        raise NotImplementedError(f"model of type {type(base_model)} has not implemented `get_hardware_flops`")
+
     def get_input_embeddings(self) -> nn.Embedding:
         """get input embedding of model
 
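A hypothetical usage sketch (not part of the patch): head models such as `LlamaForCausalLM` do not define these methods themselves; the `PretrainedModel` fallback above resolves `base_model_prefix` and delegates to the base `LlamaModel`. The config values below are made up for illustration.

# Hypothetical example; assumes this patch is applied on a recent PaddleNLP install.
from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=4,
    num_attention_heads=8,
    vocab_size=32000,
)
model = LlamaForCausalLM(config)

# Both calls are dispatched to model.llama via the PretrainedModel fallback above.
print(model.get_model_flops(seq_length=2048))                     # analytic fwd + bwd FLOPs per sample
print(model.get_hardware_flops(seq_length=2048, recompute=True))  # adds recompute cost per config.recompute_granularity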
diff --git a/paddlenlp/transformers/qwen/modeling.py b/paddlenlp/transformers/qwen/modeling.py
index fca8e55919d6..d97c2783382f 100755
--- a/paddlenlp/transformers/qwen/modeling.py
+++ b/paddlenlp/transformers/qwen/modeling.py
@@ -49,6 +49,7 @@ def swiglu(x, y=None):
 from .. import linear_utils
 from ..linear_utils import Linear
 from ..model_outputs import ModelOutput
+from ..utils import caculate_llm_flops
 from .configuration import QWenConfig
 
 try:
@@ -690,6 +691,39 @@ def __init__(self, config):
         )
         self.ln_f = QWenRMSNorm(config)
 
+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.wte
 
diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py
index e016947cf3a9..74353ae27409 100644
--- a/paddlenlp/transformers/qwen2/modeling.py
+++ b/paddlenlp/transformers/qwen2/modeling.py
@@ -44,6 +44,7 @@
     TokenClassifierOutput,
 )
 from ..model_utils import PretrainedModel, register_base_model
+from ..utils import caculate_llm_flops
 from .configuration import Qwen2Config
 
 try:
@@ -914,6 +915,39 @@ def __init__(self, config: Qwen2Config):
         )
         self.norm = Qwen2RMSNorm(config)
 
+    def get_model_flops(self, batch_size=1, seq_length=None, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=False,
+        )
+
+    def get_hardware_flops(self, batch_size=1, seq_length=None, recompute=False, **kwargs):
+        if seq_length is None:
+            if hasattr(self.config, "seq_length"):
+                seq_length = self.config.seq_length
+            else:
+                seq_length = 2048
+
+        return caculate_llm_flops(
+            hidden_size=self.config.hidden_size,
+            intermediate_size=self.config.intermediate_size,
+            layer_num=self.config.num_hidden_layers,
+            vocab_size=self.config.vocab_size,
+            seq_length=seq_length,
+            recompute=recompute,
+            recompute_granularity=self.config.recompute_granularity,
+        )
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
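Models outside the five decoder-only families touched here still work with the trainer change: the `PretrainedModel` fallback raises `NotImplementedError`, which `_maybe_log_save_evaluate` catches and turns into `model_flops = None`, so the TFLOPS entry is simply not logged. A dependency-free sketch (not part of the patch) of that control flow:

# Stand-in for a model without get_hardware_flops support; mirrors the trainer's try/except above.
class _NoFlopsModel:
    def get_hardware_flops(self, **kwargs):
        raise NotImplementedError("`get_hardware_flops` is not implemented for this model")

model_flops = None
try:
    model_flops = _NoFlopsModel().get_hardware_flops(seq_length=2048, recompute=False)
except NotImplementedError:
    model_flops = None

print(model_flops)  # None -> speed_metrics skips the hardware TFLOPS metric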
diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py
index d4c319575aca..3b12c30234f9 100644
--- a/paddlenlp/transformers/utils.py
+++ b/paddlenlp/transformers/utils.py
@@ -958,3 +958,46 @@ def __repr__(self):
         if self.err_buf:
             msg += f"stderr: {self.err}\n"
         return msg
+
+
+def caculate_llm_flops(
+    hidden_size,
+    intermediate_size,
+    layer_num,
+    vocab_size,
+    batch_size=1,
+    seq_length=None,
+    recompute=False,
+    recompute_granularity=None,
+):
+
+    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
+    flops_per_transformer = 0
+    flops_recompute_transformer = 0
+
+    # qkvo projection matmuls
+    flops_qkvo_matmul = seq_length * hidden_size**2 * 4
+
+    # q_states @ k_states:    [b, s, h] x [b, h, s] -> b * s^2 * h
+    # attn_weight @ v_states: [b, s, s] x [b, s, h] -> b * s^2 * h
+    # together these give the core-attention cost
+    flops_core_attn = seq_length**2 * hidden_size * 2
+
+    # swiglu FFN: three matmuls plus one elementwise product
+    flops_ffn = seq_length * hidden_size * intermediate_size * 3 + seq_length * intermediate_size
+
+    flops_per_transformer = flops_qkvo_matmul + flops_core_attn + flops_ffn
+    if recompute:
+        if recompute_granularity == "full":
+            flops_recompute_transformer = flops_per_transformer
+        if recompute_granularity == "full_attn":
+            flops_recompute_transformer = flops_qkvo_matmul + flops_core_attn
+        if recompute_granularity == "core_attn":
+            flops_recompute_transformer = flops_core_attn
+
+    # final logits projection
+    flops_loggits = seq_length * hidden_size * vocab_size
+
+    # 2 for the multiply + add in each matmul
+    # 1x forward, 2x backward, since we calculate gradients for both matmul inputs (input_x and input_y)
+    return 2 * batch_size * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits)
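A usage sketch (not part of the patch) that plugs assumed LLaMA-7B-like dimensions into the helper once this change is applied:

# Assumed LLaMA-7B-like dimensions; the import path is valid only after this patch is applied.
from paddlenlp.transformers.utils import caculate_llm_flops

flops = caculate_llm_flops(
    hidden_size=4096,
    intermediate_size=11008,
    layer_num=32,
    vocab_size=32000,
    seq_length=4096,
    recompute=False,
)
print(f"{flops:.3e}")  # ~1.888e+14 FLOPs for one 4096-token sample (forward + backward, no recompute)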