@@ -44,6 +44,7 @@
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -1814,7 +1815,21 @@ def _get_stats(self,
                 # TPOTs.
                 latency = seq_group.get_last_token_latency()
                 # last_token_time is set only for the last step so take avg
-                num_outputs = scheduler_outputs.num_lookahead_slots + 1
+                if current_platform.is_tt():
+                    # For the current TT model runner, the number of
+                    # steps executed is not always the same as the
+                    # number of lookahead slots but rather the number
+                    # of tokens remaining to be generated.
+                    assert len(
+                        seq_group.seqs
+                    ) == 1, "Only one seq per group is allowed for TT"
+                    total_tokens = seq_group.seqs[0].get_output_len() - 1
+                    max_steps = scheduler_outputs.num_lookahead_slots + 1
+                    num_outputs = (total_tokens % max_steps
+                                   if total_tokens % max_steps != 0
+                                   else max_steps)
+                else:
+                    num_outputs = scheduler_outputs.num_lookahead_slots + 1
                 latency /= num_outputs
                 time_per_output_tokens_iter.append(latency)
                 if seq_group.state.current_step == 0:
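
To make the TT branch's arithmetic concrete, here is a minimal standalone sketch of the same remainder logic, assuming the batching behavior described in the patch comment (up to num_lookahead_slots + 1 tokens per execution, with a shorter final batch). The name tokens_in_last_batch is a hypothetical helper for illustration, not part of the patch.

# A sketch of the remainder logic above; "tokens_in_last_batch" is a
# hypothetical helper, not part of this patch.
def tokens_in_last_batch(total_tokens: int, max_steps: int) -> int:
    """Tokens produced by the most recent multi-step execution.

    The TT runner is assumed to emit up to max_steps
    (num_lookahead_slots + 1) tokens per execution; the final
    execution covers only the remainder, so the last-token latency
    is averaged over that count rather than over max_steps.
    """
    remainder = total_tokens % max_steps
    return remainder if remainder != 0 else max_steps

# With 10 output tokens and max_steps = 4, the batches are 4, 4, 2,
# so the last latency measurement is averaged over 2 tokens.
assert tokens_in_last_batch(10, 4) == 2
assert tokens_in_last_batch(8, 4) == 4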