<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<!-- Meta tags for social media banners, these should be filled in appropriately as they are your "business card" -->
<!-- Replace the content tag with appropriate information -->
<meta name="description" content="MagicDec: Breaking the Latency-Throughput Tradeoff for Long Contexts with Speculative Decoding">
<meta property="og:title" content="MagicDec"/>
<meta property="og:description" content="MagicDec: Breaking the Latency-Throughput Tradeoff for Long Contexts with Speculative Decoding"/>
<meta property="og:url" content="https://github.com/Infini-AI-Lab/MagicDec/"/>
<!-- Path to banner image, should be in the path listed below. Optimal dimensions are 1200X630-->
<meta property="og:image" content="static/images/icons/MagicDec.png"/>
<meta property="og:image:width" content="1200"/>
<meta property="og:image:height" content="630"/>
<meta name="twitter:title" content="MagicDec">
<meta name="twitter:description" content="MagicDec: Breaking the Latency-Throughput Tradeoff for Long Contexts with Speculative Decoding">
<!-- Path to banner image, should be in the path listed below. Optimal dimensions are 1200X600-->
<meta name="twitter:image" content="static/images/icons/MagicDec.png">
<meta name="twitter:card" content="summary_large_image">
<!-- Keywords for your paper to be indexed by-->
<meta name="keywords" content="Speculative Decoding">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Speculative decoding for high-throughput long-context inference
</title>
<link rel="icon" type="image/x-icon" href="static/images/icons/MagicDec.png">
<link href="https://fonts.googleapis.com/css?family=Google+Sans|Noto+Sans|Castoro"
rel="stylesheet">
<link rel="stylesheet" href="static/css/bulma.min.css">
<link rel="stylesheet" href="static/css/bulma-carousel.min.css">
<link rel="stylesheet" href="static/css/bulma-slider.min.css">
<link rel="stylesheet" href="static/css/fontawesome.all.min.css">
<link rel="stylesheet"
href="https://cdn.jsdelivr.net/gh/jpswalsh/academicons@1/css/academicons.min.css">
<link rel="stylesheet" href="static/css/index.css">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script src="https://documentcloud.adobe.com/view-sdk/main.js"></script>
<script defer src="static/js/fontawesome.all.min.js"></script>
<script src="static/js/bulma-carousel.min.js"></script>
<script src="static/js/bulma-slider.min.js"></script>
<script src="static/js/index.js"></script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({tex2jax: {inlineMath: [['$','$'], ['\\(','\\)']]}});
</script>
<script type="text/javascript"
src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML">
</script>
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<style>
@font-face {
font-family: 'TriForceFont';
src: url('static/Triforce.ttf') format('truetype');
}
.custom-font {
font-family: 'TriForceFont', sans-serif !important;
font-size: 3.0rem;
}
body {
background-color: #f5f5f5; /* Adjust this to match your page's gray color */
}
.image-container {
background-color: #f5f5f5; /* Same as body background */
display: inline-block; /* Or 'block' depending on your layout needs */
}
.image-container img {
mix-blend-mode: multiply;
max-width: 100%;
height: auto;
}
.container.is-fluid {
margin-left: 15px;
margin-right: 15px;
max-width: none;
}
.hero .hero-body {
padding: 3rem 0;
}
.section {
padding: 3rem 0;
}
.column.is-full-width {
padding: 0 15px;
}
</style>
</head>
<body>
<!-- Section: Header Titlepage -->
<section class="hero">
<div class="hero-body">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths has-text-centered">
<img src="static/images/icons/specdec.png" alt="Magic Wand Icon" style="display: inline; height: 3rem; vertical-align: top;">
<h1 class="title is-2 publication-title" style="display: inline;">Speculative decoding for high-throughput long-context inference</h1>
<br><br>
<div class="is-size-5 publication-authors">
<span class="author-block"><a href="" target="_blank">Jian Chen</a><sup>*1</sup>,</span>
<span class="author-block"><a href="" target="_blank">Vashisth Tiwari</a><sup>*1</sup>,</span>
<span class="author-block"><a href="" target="_blank">Ranajoy Sadhukhan</a><sup>*1</sup>,</span>
<span class="author-block"><a href="https://dreaming-panda.github.io/" target="_blank">Zhuoming Chen</a><sup>1</sup>,</span>
<br>
<span class="author-block"><a href="" target="_blank">Jinyuan Shi</a><sup>2</sup>,</span>
<span class="author-block"><a href="" target="_blank">Ian En-Hsu Yen</a><sup>2</sup>,</span>
<span class="author-block"><a href="https://avnermay.github.io/" target="_blank">Avner May</a><sup>4</sup>,</span>
<span class="author-block"><a href="https://www.andrew.cmu.edu/user/beidic/" target="_blank">Beidi Chen</a><sup>1,3</sup></span>
</div>
<div class="is-size-5 publication-authors">
<span class="affliation">
<small>
<sup>1</sup>Carnegie Mellon University
<sup>2</sup>Moffett AI
<sup>3</sup>Meta AI (FAIR)
<sup>4</sup>Together AI
</small>
</span>
<span class="eql-cntrb">
<small><br><sup>*</sup>Indicates Equal Contribution</small>
</span>
</div>
<div class="column has-text-centered">
<span class="link-block">
<a href="https://arxiv.org/abs/2408.11049" target="_blank" class="external-link button is-normal is-rounded is-dark">
<span class="icon"><i class="ai ai-arxiv"></i></span>
<span>arXiv</span>
</a>
</span>
<span class="link-block">
<a href="https://github.com/Infini-AI-Lab/MagicDec/tree/main" target="_blank" class="external-link button is-normal is-rounded is-dark">
<span class="icon"><i class="fab fa-github"></i></span>
<span>Code</span>
</a>
</span>
<!-- <span class="link-block">
<a href="https://youtu.be/vRAaAyjr6Jo" target="_blank" class="external-link button is-normal is-rounded is-dark">
<span class="icon"><i class="fab fa-youtube"></i></span>
<span>Video</span>
</a>
</span> -->
</div>
</div>
</div>
</div>
</div>
</section>
<!-- Section: Paper abstract
<section class="section hero is-light">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<div class="content has-text-justified">
<p>
<strong style="font-weight: 900;color: #0f598a">TL;DR:</strong> We introduce <strong>MagicDec</strong>, which uses Speculative Decoding to improve both throughput and latency of LLM inference. Our work identifies the bottleneck shifts with increasing batch size and sequence length, and uses these insights to deploy Speculative Decoding more effectively for high throughput inference. It challenges the existing wisdom regarding the inefficacy of Speculative Decoding for large batch sizes. More interestingly, we observe an <strong>improvement in speedup with increasing batch size</strong> for moderate to long sequences. Our work theoretically motivates and empirically validates why Speculative Decoding is a potential solution for breaking throughput-latency tradeoff, when used wisely.
</p>
</div>
</div>
</div>
</div>
</section> -->
<!-- Section: Introduction -->
<section class="section hero is-light">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;">
<img src="static/images/icons/Llama.png" style="height: 43px; display: inline; vertical-align:text-top;"/>
Introduction
</h2>
<div class="content has-text-justified">
<p>
The amount of inference being performed with LLMs is growing dramatically across many different use cases, many of which exploit the ever-increasing context lengths supported by these models. Maximizing the inference throughput of these models, including at long context, is therefore becoming an increasingly important problem. Higher throughput means a lower price per token for consumers and a lower carbon footprint per token. From a capability perspective, higher throughput at long context unlocks numerous applications such as information extraction from large sets of documents, synthetic data generation for LLM training/fine-tuning, extended user-assistant chats, and agentic workflows (which typically require many LLM calls per user request). These applications often involve very long input sequences (e.g., long documents or chat histories), requiring models to process thousands of tokens to deliver intelligent outputs. High throughput at long context is particularly challenging technically because of the large memory footprint of the KV cache. Conventional wisdom (e.g., <a style="color: #209CEE" href="https://arxiv.org/pdf/2302.01318">Chen et al., 2023</a>; <a style="color: #209CEE" href="https://arxiv.org/pdf/2401.15077">Li et al., 2024</a>; <a style="color: #209CEE" href="https://arxiv.org/pdf/2406.14066">Liu et al., 2024</a>) holds that speculative decoding, which leverages GPU compute left idle during memory-bound decoding, does not make sense in the high-throughput regime (i.e., large batch sizes), because decoding there is compute-bound and the GPUs are already fully utilized. Surprisingly, we show analytically and empirically that for large batch sizes, if the input sequences are long enough, decoding once again becomes memory-bound due to the large size of the KV cache.
Building on this key observation, we demonstrate that speculative decoding can improve both throughput and latency by up to <strong>2x on 8 A100s in this large-batch, long-context setting</strong>.
</p>
<p>
In this blog post, we first take a deep dive into the forward pass time of a single transformer layer during autoregressive decoding, and provide a simple equation, which we validate empirically, that describes when the forward pass will be memory-bound on given hardware. More specifically, we analyze the fraction of the forward pass time taken by loading the KV cache. This analysis shows that even for very large batch sizes, the layer will be memory-bound during decoding whenever the context length exceeds a threshold.
</p>
<p>
After presenting this analysis, we describe how speculative decoding can increase throughput in the long-context, large-batch regime. In particular, we propose two algorithmic innovations:
<br>
<div class="disp">
<ol>
<li>
<a style="color: #209CEE" href="https://arxiv.org/pdf/2408.11049">MagicDec</a>: Because the decoding bottleneck at large batch size and long context is loading the KV cache, MagicDec gives the draft model a fixed context window, which makes drafting many times faster than running the target model (the draft KV cache stays the same size no matter how long the sequence grows). Furthermore, because loading the target model parameters is no longer the bottleneck in this regime, we can afford to use a very large and powerful draft model; we can even use the full target model as the draft model, as long as it uses a fixed context window. Based on these insights, MagicDec combines ideas from TriForce and StreamingLLM: its draft model is a StreamingLLM model (sliding window attention plus an attention sink), and it uses staged speculative decoding for further speedups during drafting. Intriguingly, in this regime, <strong>we get larger speedups the higher the batch size!</strong></li>
<li>
<a style="color: #209CEE" target="_blank">Adaptive Sequoia trees</a>: Leveraging our observation that there is a sequence length threshold above which decoding becomes memory-bound, and that decoding becomes increasingly memory-bound for even longer sequences, we propose choosing the amount of speculation as a function of the sequence length (longer sequence length → more speculated tokens). We use the Sequoia algorithm (see our <a style="color: #209CEE" href="https://arxiv.org/abs/2402.12374">paper</a>, <a style="color: #209CEE" href="https://www.together.ai/blog/sequoia">blog</a>) to determine the tree structure for the speculated tokens that maximizes the expected number of generated tokens.
</li>
</ol>
</div>
<br>
We now turn to our deep dive into a single transformer layer.
</p>
</div>
</div>
</div>
</div>
</section>
<!-- <p>
To answer this question, we revisit the efficacies of Speculative Decoding through an analytical approach. Existing works (e.g., <a style="color: #209CEE" href="https://arxiv.org/abs/2406.14066">Liu et al, 2024</a>; <a style="color: #209CEE" href="https://arxiv.org/abs/2310.18813">Su et al, 2023</a>) have claimed that although useful for small batch sizes, Speculative Decoding can become counter-productive for serving large batches. This inefficiency stems from two primary factors:
<ol>
<li>For small batch sizes, Speculative Decoding tries to use the underutilized compute resources through parallel verification of speculated tokens, but for large batch sizes, the compute resources are already well-utilized, diminishing the potential benefits.</li>
<li>If the draft model is not capable enough, the low acceptance rate can cause the target model to spend more time verifying tokens that are ultimately rejected.</li>
</ol>
</p> -->
<section class="section hero is-light">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;">
<img src="static/images/icons/deep-dive.png" style="height: 43px; display: inline; vertical-align:text-top;"/>
Deep dive: When is decoding for a single transformer layer dominated by loading the KV cache?
</h2>
<div class="content has-text-justified">
<p>
Here, we analyze when the decoding forward pass time of a single transformer layer is dominated by loading the KV cache. We show that as the context length and batch size increase, most of the time is spent on loading the KV cache.
<br>
For this analysis, we split the operations during the forward pass into two types: operations involving the model parameters, and operations involving the KV cache. For each type, we compute the number of FLOPs as well as the amount of memory that must be transferred. While the operations involving model parameters become compute-bound as the batch size increases (their arithmetic intensity equals the batch size $b$), the operations involving the KV cache are always memory-bound (their arithmetic intensity is a small constant, $g$, because each sequence in the batch has its own KV cache). Because the memory taken by the KV cache grows linearly with both the batch size and the average sequence length, whereas the model parameter FLOPs do not depend on the sequence length, the forward pass time becomes dominated by loading the KV cache as the average sequence length increases.
</p>
<p>
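Here, we assume a standard two-matrix MLP (no gating) with intermediate size $4d$, where $d$ is the model dimension, $b$ is the batch size, and $n$ is the current prefix length. We also assume grouped-query attention (GQA), where $g$ is the ratio of query heads to key/value heads, and 16-bit (2-byte) storage for both the weights and the KV cache.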
</p>
<!-- <style>
table {
border-collapse: collapse;
width: 100%;
max-width: 800px;
margin: 20px auto;
}
th, td {
border: 1px solid black;
padding: 8px;
text-align: center;
}
th {
background-color: #f2f2f2;
}
caption {
margin-bottom: 10px;
font-weight: bold;
}
</style> -->
<table >
<caption>Table 1: Memory and compute of a single transformer layer during decoding, split into operations on the model parameters (MLP weights and $W_{Q,K,V,O}$) and operations on the KV cache. $g$ is the memory reduction factor from GQA ($g$ = num_attention_heads / num_key_value_heads).</caption>
<thead>
<tr>
<th></th>
<th>Model Params</th>
<th>KV cache</th>
</tr>
</thead>
<tbody>
<tr>
<td>Memory (bytes)</td>
<td>
<math xmlns="http://www.w3.org/1998/Math/MathML">
<mn>2</mn>
<mo>*</mo>
<mo>(</mo>
<mn>10</mn>
<msup>
<mi>d</mi>
<mn>2</mn>
</msup>
<mo>+</mo>
<mn>2</mn>
<msup>
<mi>d</mi>
<mn>2</mn>
</msup>
<mo>/</mo>
<mi>g</mi>
<mo>)</mo>
</math>
</td>
<td>
<math xmlns="http://www.w3.org/1998/Math/MathML">
<mn>2</mn>
<mo>*</mo>
<mn>2</mn>
<mi>b</mi>
<mi>n</mi>
<mi>d</mi>
<mo>/</mo>
<mi>g</mi>
</math>
</td>
</tr>
<tr>
<td>Compute (FLOPs)</td>
<td>
<math xmlns="http://www.w3.org/1998/Math/MathML">
<mn>2</mn>
<mi>b</mi>
<mo>*</mo>
<mo>(</mo>
<mn>10</mn>
<msup>
<mi>d</mi>
<mn>2</mn>
</msup>
<mo>+</mo>
<mn>2</mn>
<msup>
<mi>d</mi>
<mn>2</mn>
</msup>
<mo>/</mo>
<mi>g</mi>
<mo>)</mo>
</math>
</td>
<td>
<math xmlns="http://www.w3.org/1998/Math/MathML">
<mn>2</mn>
<mo>*</mo>
<mn>2</mn>
<mi>b</mi>
<mi>n</mi>
<mi>d</mi>
</math>
</td>
</tr>
<tr>
<td>Arithmetic intensity</td>
<td>
<math xmlns="http://www.w3.org/1998/Math/MathML">
<mi>b</mi>
</math>
</td>
<td>
<math xmlns="http://www.w3.org/1998/Math/MathML">
<mi>g</mi>
</math>
</td>
</tr>
</tbody>
</table>
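<p>
To make Table 1 concrete, here is a minimal Python sketch that evaluates its entries for one layer. The names <code>layer_cost</code> and <code>LayerCost</code> are hypothetical, and the shapes in the usage lines (d = 4096, g = 4, b = 64, n = 32000) are illustrative values rather than a specific model. The sketch checks that the two arithmetic intensities come out as b and g, and compares KV-cache bytes with weight bytes.
</p>

```python
from dataclasses import dataclass


@dataclass
class LayerCost:
    """Per-layer decode cost, split as in Table 1."""
    param_bytes: float
    param_flops: float
    kv_bytes: float
    kv_flops: float


def layer_cost(d: int, g: int, b: int, n: int, bytes_per_elem: int = 2) -> LayerCost:
    """Table 1 entries for one transformer layer decoding one token per sequence.

    d: model dimension, g: query heads per KV head (GQA), b: batch size,
    n: current prefix length. Assumes a non-gated 4*d MLP and 16-bit storage.
    """
    # MLP: 2 * (d x 4d) = 8 d^2; W_Q and W_O: 2 d^2; W_K and W_V: 2 * d * (d / g)
    num_params = 10 * d * d + 2 * d * d / g
    param_bytes = bytes_per_elem * num_params
    param_flops = 2 * b * num_params                 # one multiply-add per weight per sequence
    kv_bytes = bytes_per_elem * 2 * b * n * d / g    # K and V, each b * n * (d / g) elements
    kv_flops = 2 * 2 * b * n * d                     # Q K^T plus the weighted sum over V
    return LayerCost(param_bytes, param_flops, kv_bytes, kv_flops)


c = layer_cost(d=4096, g=4, b=64, n=32000)
print(c.param_flops / c.param_bytes)   # 64.0  (= b)
print(c.kv_flops / c.kv_bytes)         # 4.0   (= g)
print(c.kv_bytes / c.param_bytes)      # about 23.8: KV bytes far exceed weight bytes
```

<p>
Because the weight bytes do not depend on b or n while the KV bytes grow with both, the last ratio keeps increasing with longer contexts and larger batches.
</p>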
<!-- <div style="display: flex; flex-wrap: wrap; justify-content: center;">
<div style="width: 45%; min-width: 150px; margin: 5px;">
<canvas id="chart1"></canvas>
</div>
<div style="width: 45%; min-width: 150px; margin: 5px;">
<canvas id="chart2"></canvas>
</div>
</div> -->
<!-- <script src="static/js/plots/throughput_latency_smaller.js"></script> -->
<br>
<p>
<!-- We note that our sparse KV cache based drafting is absolutely key in high batch size and sequence length regime. To retain a high acceptance rate even for longer contexts, we rely on a simple but effective KV sparsification strategy called <a style="color: #209CEE" href="https://arxiv.org/abs/2309.17453">StreamingLLM</a>. -->
From this table, it is easy to see that for a large enough sequence length $n$ (and batch size $b$), the time to load the KV cache will far exceed the time spent on operations involving the model parameters, regardless of whether those operations are compute-bound or memory-bound.
<br>
In <a style="color: #0b3c5d" href="#figure-1">Figure 1</a> we empirically validate that loading the KV cache dominates the forward pass time of a transformer layer as the sequence length and batch size increase. In particular, we plot the fraction of decode time taken by the operations over the KV cache for a transformer layer with a model dimension of 1024. As the sequence length increases, the empirical fraction approaches 1, and it approaches 1 more quickly for larger batch sizes. This result was quite exciting and surprising to us: counterintuitively, in the long-context regime, a larger batch size makes decoding <strong>more memory-bound</strong>, not less. The community's focus on short and medium contexts may have caused this fact to be overlooked until now.
<div id="figure-1" class="image-container" style="display: flex; flex-direction: column; align-items: center;">
<div style="display: flex; justify-content: space-around; width: 100%;">
<img src="static/images/frac_fwd_pass.png" alt="Fraction of KV load time in Fwd pass" style="height: 500px; width: 600px; margin: 0 10px;" />
</div>
<figcaption style="margin-top: 10px; text-align: center;">
<strong>Figure 1: Fraction of the decode forward pass time of a transformer layer (model dimension 1024) taken by loading the KV cache, for different batch sizes and sequence lengths, on an H100.
</strong>
</figcaption>
</div>
</p>
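<p>
The trend in Figure 1 can be roughly reproduced with an idealized roofline model built from Table 1, in which each group of operations takes the larger of its compute time and its memory time. The sketch below is an approximation, not the code used to produce Figure 1: the function name is hypothetical, the peak FLOP/s and bandwidth defaults are rough H100 SXM (dense BF16, HBM3) figures, and g = 1 (no GQA) is an assumption, since the figure does not state it.
</p>

```python
def kv_time_fraction(d, g, b, n, peak_flops=989e12, mem_bw=3.35e12, bytes_per_elem=2):
    """Idealized roofline estimate of the share of one layer's decode time spent on
    KV-cache operations. Each group of operations (Table 1) is charged the larger of
    its compute time and its memory time; kernel overheads are ignored."""
    num_params = 10 * d * d + 2 * d * d / g
    t_params = max(2 * b * num_params / peak_flops,
                   bytes_per_elem * num_params / mem_bw)
    t_kv = max(4 * b * n * d / peak_flops,
               bytes_per_elem * 2 * b * n * d / g / mem_bw)
    return t_kv / (t_kv + t_params)


# d = 1024 as in Figure 1; g = 1 and the H100 numbers are assumptions.
for b in (1, 16, 64, 256):
    fractions = [round(kv_time_fraction(1024, 1, b, n), 3) for n in (1024, 8192, 65536)]
    print(f"b={b:4d}", fractions)
```

<p>
In this idealized model the fraction rises toward 1 as n grows, and rises faster for larger b until the parameter operations also become compute-bound. Measured numbers will differ, since real kernels do not reach peak FLOP/s or peak bandwidth.
</p>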
</div>
</div>
</div>
</div>
</section>
<!-- Section: Motivation -->
<section class="section hero is-light">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;">
<img src="static/images/icons/Idea.png" style="height: 50px; display: inline; vertical-align: middle;"/>
Enter speculative decoding
</h2>
<div class="content has-text-justified">
<!-- <p>
Speculative decoding leverages underutilized GPU compute during memory-bound autoregressive decoding. However, <a href="https://arxiv.org/pdf/2406.14066" rel="external nofollow noopener" target="_blank" style="color: #209CEE;">prior research</a> has shown that Speculative Decoding becomes less effective as batch sizes increase and exhaust the available compute. These challenges have led to Speculative Decoding being avoided in processing large batches. However, we show that for long sequences, a shift in LLM operation bottleneck allows Speculative Decoding to become more effective with larger batch sizes.
</p> -->
<p>
Based on the above observations, we propose using speculative decoding to improve LLM throughput and latency during decoding in the large-batch, long-context regime. Intuitively, because the KV cache operations are memory-bound and dominate the forward pass time in this regime, there is idle compute that speculative decoding can use. More specifically, the verification time (T<sub>verify</sub>) for L speculated tokens will be quite similar to the regular decode time (T<sub>decode</sub>), because the operations involving the KV cache remain memory-bound as L increases (and therefore take roughly the same amount of time). Although the time for the operations involving the model parameters can increase by up to a factor of L, the total time will not increase much when loading the KV cache dominates the decode time. Therefore, as long as drafting these L tokens (T<sub>draft</sub>) is relatively fast and the acceptance rate is high enough, speculative decoding will yield speedups (see the speedup equation below).
<br>
<div class="equation-container">
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
<mrow>
<mi>speedup</mi>
<mo>=</mo>
<mfrac>
<mrow>
<mtext>Expected number of generated tokens</mtext>
</mrow>
<mrow>
<mfrac>
<msub>
<mi>T</mi>
<mtext>verify</mtext>
</msub>
<msub>
<mi>T</mi>
<mtext>decode</mtext>
</msub>
</mfrac>
<mo>+</mo>
<mfrac>
<msub>
<mi>T</mi>
<mtext>draft</mtext>
</msub>
<msub>
<mi>T</mi>
<mtext>decode</mtext>
</msub>
</mfrac>
</mrow>
</mfrac>
</mrow>
</math>
</div>
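<p>
For readers who want to plug in numbers, the speedup equation can be written as a small Python function. The names are hypothetical, and the expected-token helper assumes a single chain of drafted tokens, each accepted independently with the same probability; this common simplification is not the tree-based estimate used by Sequoia. The timing ratios in the usage lines are made-up illustrative values, not measurements.
</p>

```python
def expected_tokens_chain(alpha: float, num_spec: int) -> float:
    """Expected tokens produced per verification step for a single chain of
    `num_spec` drafted tokens, assuming each is accepted independently with
    probability `alpha` (a simplifying assumption). Counts the extra token the
    target model always contributes."""
    if alpha == 1.0:
        return num_spec + 1.0
    return (1 - alpha ** (num_spec + 1)) / (1 - alpha)


def speedup(expected_generated: float, t_verify: float, t_draft: float, t_decode: float) -> float:
    """The speedup equation above; t_draft is the total time to draft all tokens."""
    return expected_generated / (t_verify / t_decode + t_draft / t_decode)


# Hypothetical inputs: 4 drafted tokens, acceptance rate 0.8,
# T_verify = 1.1 * T_decode (long-context regime), T_draft = 0.4 * T_decode.
tokens = expected_tokens_chain(alpha=0.8, num_spec=4)
print(round(tokens, 4))                                                    # 3.3616
print(round(speedup(tokens, t_verify=1.1, t_draft=0.4, t_decode=1.0), 3))  # 2.241
```

<p>
The example shows why T<sub>verify</sub>/T<sub>decode</sub> close to 1 matters: the denominator then stays small, so most of the expected tokens translate directly into speedup.
</p>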
<br>
In <a style="color: #0b3c5d" href="#figure-2">Figure 2</a>, we show that for large sequence lengths, T<sub>verify</sub>/T<sub>decode</sub> approaches 1, which implies that speculative decoding can give meaningful speedups.
<br> <br>
<div id="figure-2" class="image-container" style="display: flex; flex-direction: column; align-items: center;">
<div style="display: flex; justify-content: space-around; width: 100%;">
<img src="static/images/verify_2_autoreg.png" alt="verification to autoreg" style="height: 500px; width: 600px; margin: 0 10px;" />
</div>
<figcaption style="margin-top: 10px; text-align: center;">
<strong>Figure 2: T<sub>verify</sub>/T<sub>decode</sub> for various batch sizes, as a function of sequence length, on an H100.
</strong>
</figcaption>
</div>
<br>
We will now detail our two algorithmic innovations—MagicDec and adaptive Sequoia trees—related to performing speculative decoding in this high-throughput regime.
</p>
</div>
</div>
</div>
</div>
</section>
<section class="section hero is-light">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;">
<img src="static/images/icons/MagicDec.png" style="height: 50px; display: inline; vertical-align: middle;"/>
MagicDec
</h2>
<div class="content has-text-justified">
<p>
A low draft-to-verify cost ratio is ideal for speculative decoding. In the low-latency regime where speculative decoding is normally applied (i.e., small batch sizes), the decoding bottleneck is loading the target model parameters, so a small draft model is generally the key to a low draft-to-verify ratio. In the high-throughput regime we are interested in here, however, the bottleneck is loading the target model's KV cache. This shift in bottlenecks opens up better drafting strategies. In particular, we can afford to use a larger and more powerful draft model as long as its KV cache is kept small.
</p>
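<p>
A back-of-the-envelope roofline sketch shows why a large draft with a small KV cache can be affordable here. It compares the cost of drafting gamma tokens with the full target weights but a fixed 256-token KV cache (self-speculation) against one verification pass over the full context. Everything in it is an assumption for illustration: the function names, LLaMA-2-7B-like shapes (about 7B parameters, 32 layers, model dimension 4096, no GQA), rough single-A100 FLOP/s and bandwidth figures, and the absence of tensor parallelism and kernel overheads.
</p>

```python
def step_time(num_params, kv_elems_per_seq, b, tokens=1,
              peak_flops=312e12, mem_bw=2.0e12, bytes_per_elem=2):
    """Idealized roofline time of one forward pass that processes `tokens` new
    tokens for each of `b` sequences. kv_elems_per_seq = K+V elements read per sequence.
    peak_flops / mem_bw are rough single-A100 (BF16 dense / HBM) figures."""
    t_params = max(2 * b * tokens * num_params / peak_flops,
                   bytes_per_elem * num_params / mem_bw)
    # Table 1: KV-cache reads have arithmetic intensity g, so treat them as memory-bound.
    t_kv = bytes_per_elem * b * kv_elems_per_seq / mem_bw
    return t_params + t_kv


def self_spec_draft_to_verify(b, n, gamma=4, budget=256,
                              num_params=7e9, kv_per_token=2 * 32 * 4096):
    """Cost of drafting `gamma` tokens with the target weights and a fixed
    `budget`-token KV cache, relative to verifying gamma + 1 tokens at context n."""
    t_draft = gamma * step_time(num_params, budget * kv_per_token, b)
    t_verify = step_time(num_params, n * kv_per_token, b, tokens=gamma + 1)
    return t_draft / t_verify


for b in (1, 32, 128):
    print(b, round(self_spec_draft_to_verify(b, n=8000), 3))
```

<p>
In this idealized model the draft-to-verify ratio is well above 1 at batch size 1, where reading the 7B weights dominates both passes, and falls steadily as the batch grows and the full-context KV cache comes to dominate verification.
</p>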
<p>
Thus, we propose using self-speculation, where the target model itself serves as the draft model, but with a limited context size. More specifically, we use StreamingLLM, which combines sliding window attention with an “attention sink” (the KV entries of the initial tokens are always retained) to limit the size of the KV cache. While the draft cost grows with batch size, mainly due to increased computation time, the verification cost grows even faster because of the larger KV loading time. As a result, the draft-to-target cost ratio decreases with increasing batch size (see <a style="color: #0b3c5d" href="#figure-3">Figure 3</a>), which, surprisingly, makes speculative decoding more effective at larger batch sizes. To further speed up drafting, we can use staged speculative decoding, similar to <a style="color: #209CEE" href="https://arxiv.org/pdf/2404.11912">TriForce</a>.
<br> <br>
<div id="figure-3" class="image-container" style="display: flex; flex-direction: column; align-items: center;">
<div style="display: flex; justify-content: space-around; width: 100%;">
<img src="static/images/draft_2_target.png" alt="draft to target" style="height: 500px; width: 1000px; margin: 0 10px;" />
</div>
<figcaption style="margin-top: 10px; text-align: center;">
<strong>Figure 3: Theoretical ratio of self-speculation drafting time (StreamingLLM budget = 256) to verification time, as a function of batch size (left: LLaMA-2-7B, right: LLaMA-3.1-8B), on 8xA100.
</strong>
</figcaption>
</div>
</p>
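<p>
The fixed-size draft cache can be illustrated with a short sketch that lists which prefix positions a StreamingLLM-style cache keeps: a few initial "sink" tokens plus a sliding window of the most recent tokens. The function name and the default of 4 sink tokens are assumptions for illustration, not necessarily MagicDec's settings. StreamingLLM also assigns positional encodings by position within the cache rather than in the original text; that detail is omitted here.
</p>

```python
def streaming_kv_positions(n: int, budget: int = 256, num_sink: int = 4) -> list[int]:
    """Prefix positions kept in a StreamingLLM-style draft KV cache of size `budget`:
    the first `num_sink` tokens (attention sinks) plus the most recent tokens."""
    if budget >= n:
        return list(range(n))           # short prefix: nothing to evict
    window = budget - num_sink          # slots left for the sliding window
    return list(range(num_sink)) + list(range(n - window, n))


kept = streaming_kv_positions(32000)
print(len(kept))                         # 256
print(kept[:6])                          # [0, 1, 2, 3, 31748, 31749]
print(len(streaming_kv_positions(100)))  # 100
```

<p>
However long the prefix grows, the draft reads at most <code>budget</code> KV entries per sequence, which is why its cost stays flat as n increases.
</p>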
<p>In Table 2, we present preliminary results with speedups of up to 2x for LLaMA-2-7B-32K and 1.84x for LLaMA-3.1-8B on 8 A100 GPUs.</p>
<p>
<style>
#table1 {
border-collapse: collapse;
width: 100%;
max-width: 800px;
margin: 20px auto;
}
th, td {
border: 1px solid black;
padding: 8px;
text-align: center;
}
th {
background-color: #f2f2f2;
}
#table1 .bold-border {
border-bottom: 3px solid black;
}
caption {
margin-bottom: 10px;
font-weight: bold;
}
</style>
<table id="table1">
<caption>Table 2: End-to-end speculative decoding speedups for various target-draft pairs on 8xA100s.</caption>
<thead>
<tr>
<th>Target</th>
<th>Draft</th>
<th>Prefill length (tokens)</th>
<th>Batch size</th>
<th>Optimal speculation length</th>
<th>Speedup</th>
</tr>
</thead>
<tbody>
<tr>
<td rowspan="4">Llama2-7b-32k</td>
<td>TinyLlama-1.1B</td>
<td>8000</td>
<td>32</td>
<td>3</td>
<td>1.29</td>
</tr>
<tr>
<td>TinyLlama-1.1B</td>
<td>8000</td>
<td>64</td>
<td>3</td>
<td>1.57</td>
</tr>
<tr>
<td>TinyLlama-1.1B</td>
<td>8000</td>
<td>128</td>
<td>4</td>
<td>1.66</td>
</tr>
<tr class="bold-border">
<td>TinyLlama-1.1B</td>
<td>32000</td>
<td>32</td>
<td>4</td>
<td>1.91</td>
</tr>
<tr>
<td rowspan="4">Llama2-7b-32k</td>
<td>Self-spec</td>
<td>8000</td>
<td>32</td>
<td>3</td>
<td>1.18</td>
</tr>
<tr>
<td>Self-spec</td>
<td>8000</td>
<td>64</td>
<td>3</td>
<td>1.48</td>
</tr>
<tr>
<td>Self-spec</td>
<td>8000</td>
<td>128</td>
<td>4</td>
<td>1.63</td>
</tr>
<tr class="bold-border">
<td>Self-spec</td>
<td>32000</td>
<td>32</td>
<td>4</td>
<td>2.00</td>
</tr>
<tr>
<td rowspan="4">Llama3.1-8b</td>
<td>Self-spec</td>
<td>32000</td>
<td>32</td>
<td>3</td>
<td>1.22</td>
</tr>
<tr>
<td>Self-spec</td>
<td>32000</td>
<td>64</td>
<td>3</td>
<td>1.38</td>
</tr>
<tr>
<td>Self-spec</td>
<td>32000</td>
<td>128</td>
<td>4</td>
<td>1.47</td>
</tr>
<tr>
<td>Self-spec</td>
<td>100000</td>
<td>32</td>
<td>5</td>
<td>1.84</td>
</tr>
</tbody>
</table>
For more details about this work, and additional results, please refer to our <a style="color: #209CEE" href="https://arxiv.org/abs/2408.11049">paper</a>.
</p>
</div>
</div>
</div>
</div>
</section>
<section class="section hero is-light">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;">
<img src="static/images/icons/sequoia.png" style="height: 50px; display: inline; vertical-align: middle;"/>
Adaptive Sequoia Trees
</h2>
<div class="content has-text-justified">
<p>
When we perform speculative decoding with a tree of size L, the total number of FLOPs is multiplied by L+1 (because the new token generated by the target model, as well as the L speculated tokens, must be processed by the target model), while the amount of memory that must be transferred stays roughly constant. Therefore, the FLOPs-to-memory ratio (arithmetic intensity) R of a decoding step, which can be computed from Table 1, is simply multiplied by (L+1). Based on this observation, one simple approach is to use R to find, for each context length, the largest value of L for which verification remains memory-bound. However, this approach is coarse, as it ignores both the cost of drafting the tree and the marginal gain from increasing the tree size.
</p>
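<p>
The coarse rule above can be sketched in a few lines: compute R for one decoding step from Table 1, then find the largest L with (L + 1) · R still at or below the hardware ridge point (peak FLOP/s divided by memory bandwidth). The function names, the illustrative shapes in the loop (d = 4096, g = 4, b = 64), and the rough H100 ridge point are assumptions; like the rule itself, the sketch ignores drafting cost and treats the tree's extra KV traffic as negligible.
</p>

```python
import math


def arithmetic_intensity(d, g, b, n, bytes_per_elem=2):
    """R = FLOPs / bytes for one layer's decode step (one token per sequence), from Table 1."""
    num_params = 10 * d * d + 2 * d * d / g
    flops = 2 * b * num_params + 4 * b * n * d
    mem = bytes_per_elem * (num_params + 2 * b * n * d / g)
    return flops / mem


def max_memory_bound_tree_size(d, g, b, n, ridge=989e12 / 3.35e12):
    """Largest L such that (L + 1) * R stays at or below the ridge point
    (rough H100 peak FLOP/s divided by HBM bandwidth, an assumption).
    Returns 0 when there is no headroom for even one speculated token."""
    r = arithmetic_intensity(d, g, b, n)
    return max(0, math.floor(ridge / r) - 1)


for n in (1000, 8000, 32000):
    print(n, max_memory_bound_tree_size(d=4096, g=4, b=64, n=n))
```

<p>
In this sketch the memory-bound headroom grows with the context length, which is the trend that adaptive Sequoia trees exploit.
</p>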
<p>
We therefore refine this approach by explicitly searching, for each context length, for the tree size that maximizes a speedup equation. Similar to Section 3.3.1 of the <a style="color: #209CEE" href="https://arxiv.org/abs/2402.12374">Sequoia</a> paper, we can express the speedup as follows (where b = batch size, n = sequence length, L = tree size, D = tree depth, G(L, D) = expected number of generated tokens, and T<sub>model</sub>(b, n, L) = forward pass time of the target or draft model when processing L tokens per sequence):
</p>
<div class="equation-container">
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block">
<mrow>
<mi>speedup</mi>
<mrow>
<mo>(</mo>
<mi>b</mi>
<mo>,</mo>
<mi>n</mi>
<mo>,</mo>
<mi>L</mi>
<mo>,</mo>
<mi>D</mi>
<mo>)</mo>
</mrow>
<mo>=</mo>
<mfrac>
<mrow>
<mi>G</mi>
<mo>(</mo>
<mi>L</mi>
<mo>,</mo>
<mi>D</mi>
<mo>)</mo>
</mrow>
<mrow>
<mfrac>
<mrow>
<msub>
<mi>T</mi>
<mtext>target</mtext>
</msub>
<mo>(</mo>
<mi>b</mi>
<mo>,</mo>
<mi>n</mi>
<mo>,</mo>
<mi>L</mi>
<mo>)</mo>
<mo>+</mo>
<mi>D</mi>
<mo>⋅</mo>
<msub>
<mi>T</mi>
<mtext>draft</mtext>
</msub>
<mo>(</mo>
<mi>b</mi>
<mo>,</mo>
<mi>n</mi>
<mo>,</mo>
<mn>1</mn>
<mo>)</mo>
</mrow>
<mrow>
<msub>
<mi>T</mi>
<mtext>target</mtext>
</msub>
<mo>(</mo>
<mi>b</mi>
<mo>,</mo>
<mi>n</mi>
<mo>,</mo>
<mn>1</mn>
<mo>)</mo>
</mrow>
</mfrac>
</mrow>
</mfrac>
</mrow>
</math>
</div>
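<p>The search described above can be sketched in Python as follows. This is a minimal, hypothetical implementation rather than the authors' code: <code>speedup</code> evaluates the equation term by term, and <code>best_tree</code> simply tries every candidate (L, D) pair. The candidate shapes, the G values, and the linear cost models in the usage part are made-up toy numbers, chosen only to show that the best tree can grow with the context length n.</p>

```python
from typing import Callable, Iterable, Tuple

GFn = Callable[[int, int], float]        # G(L, D)
TFn = Callable[[int, int, int], float]   # T_model(b, n, L)


def speedup(b: int, n: int, L: int, D: int,
            G: GFn, T_target: TFn, T_draft: TFn) -> float:
    # Time of one speculative step (verify L tokens, plus D sequential draft
    # steps), in units of one ordinary target decoding step.
    rel_cost = (T_target(b, n, L) + D * T_draft(b, n, 1)) / T_target(b, n, 1)
    return G(L, D) / rel_cost


def best_tree(b: int, n: int, candidates: Iterable[Tuple[int, int]],
              G: GFn, T_target: TFn, T_draft: TFn) -> Tuple[int, int]:
    # Exhaustive search over candidate (L, D) tree shapes.
    return max(candidates,
               key=lambda LD: speedup(b, n, LD[0], LD[1], G, T_target, T_draft))


# Toy, made-up models for illustration only (not measurements).
TOY_G = {(4, 2): 2.0, (16, 4): 3.0, (64, 6): 3.5}


def toy_G(L: int, D: int) -> float:
    return TOY_G[(L, D)]


def toy_T_target(b: int, n: int, L: int) -> float:
    # Constant term + a term growing with context + a per-token term.
    return 1.0 + 1e-4 * b * n + 0.02 * b * L


def toy_T_draft(b: int, n: int, L: int) -> float:
    return 0.05  # toy draft cost, independent of n


for n in (1_000, 100_000):
    print(n, best_tree(1, n, TOY_G, toy_G, toy_T_target, toy_T_draft))
# 1000 (16, 4)
# 100000 (64, 6)
```

<p>In this toy setting, the context-dependent term makes T<sub>target</sub>(b, n, 1) large at n = 100000, so verifying 64 tokens costs relatively little extra and the larger tree wins; at n = 1000 the per-token cost dominates and a medium tree is best.</p>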
<br>
<p>
For G(L, D), we can compute the maximal expected number of generated tokens achievable by a Sequoia tree of size L and depth D, using a dynamic program as in Sequoia. For T<sub>model</sub>(b, n, L), we can measure the forward-pass times of the target and draft models over many combinations of b, n, and L, and optionally fit these measurements with a parametric function.
</p>
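<p>Both ingredients can be prototyped in a few lines. The sketch below is a simplified stand-in, not Sequoia's or MagicDec's code. For G(L, D) it assumes what we understand to be Sequoia's positional acceptance model, in which the k-th child of any node is accepted with a fixed probability p[k], and runs a small dynamic program over trees with at most L speculated tokens (the root is the token already produced by the target model) and depth at most D. For T<sub>model</sub>(b, n, L) it fits a hypothetical linear cost model, T ≈ w0 + w1·b·n/1000 + w2·b·L, to measured timings by least squares. The functional form, the helper names <code>make_G</code> and <code>fit_time_model</code>, and any acceptance vector passed in are assumptions.</p>

```python
from functools import lru_cache
from typing import Callable, List, Sequence, Tuple


def make_G(p: Sequence[float]) -> Callable[[int, int], float]:
    """Return G(L, D) under an assumed positional acceptance vector p,
    where p[k] is the acceptance probability of the k-th child of any node."""

    @lru_cache(maxsize=None)
    def c(n: int, d: int) -> float:
        # Best expected tokens for a tree of at most n nodes (root included)
        # and at most d levels; the root always contributes one token.
        if n == 1 or d == 1:
            return 1.0
        return 1.0 + h(n - 1, 0, d - 1)

    @lru_cache(maxsize=None)
    def h(m: int, i: int, d: int) -> float:
        # Best split of at most m nodes among children i, i+1, ... (each a
        # subtree of at most d levels); child i+1 is used only if child i is.
        if m == 0 or i == len(p):
            return 0.0
        best = 0.0  # option: add no further children
        for a in range(1, m + 1):
            best = max(best, p[i] * c(a, d) + h(m - a, i + 1, d))
        return best

    # L speculated tokens below the root, D speculated levels.
    return lambda L, D: c(L + 1, D + 1)


def fit_time_model(samples: List[Tuple[int, int, int, float]]):
    """Least-squares fit of T(b, n, L) = w0 + w1*b*n/1000 + w2*b*L to
    (b, n, L, seconds) measurements via the 3x3 normal equations."""
    X = [(1.0, b * n / 1000.0, float(b * L)) for b, n, L, _ in samples]
    y = [s[3] for s in samples]
    A = [[sum(r[i] * r[j] for r in X) for j in range(3)] for i in range(3)]
    v = [sum(r[i] * t for r, t in zip(X, y)) for i in range(3)]
    for col in range(3):  # Gaussian elimination with partial pivoting
        piv = max(range(col, 3), key=lambda row: abs(A[row][col]))
        A[col], A[piv], v[col], v[piv] = A[piv], A[col], v[piv], v[col]
        for row in range(col + 1, 3):
            f = A[row][col] / A[col][col]
            A[row] = [x - f * z for x, z in zip(A[row], A[col])]
            v[row] -= f * v[col]
    w = [0.0, 0.0, 0.0]
    for row in (2, 1, 0):  # back substitution
        w[row] = (v[row] - sum(A[row][j] * w[j] for j in range(row + 1, 3))) / A[row][row]
    return lambda b, n, L: w[0] + w[1] * b * n / 1000.0 + w[2] * b * L
```

<p>For example, with a hypothetical p = [0.6, 0.2], a two-token chain gives G(2, 2) = 1 + 0.6 + 0.6·0.6 = 1.96, which beats two siblings under the root, G(2, 1) = 1 + 0.6 + 0.2 = 1.8. The dynamic program costs roughly O(L<sup>2</sup> · len(p) · D) operations, which is negligible next to profiling the models.</p>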
<p>
Stay tuned for our forthcoming paper, which combines adaptive Sequoia trees with a highly optimized, pipeline-parallel FP8 system designed to maximize throughput.
</p>
</div>
</div>
</div>
</div>
</section>
<!-- Section: Conclusion and Future Work -->
<section class="section hero is-light">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;">
<img src="static/images/icons/Telescope.png" style="height: 50px; display: inline; vertical-align: middle;"/>
Conclusion and Future Work
</h2>
<div class="content has-text-justified">
<p>
This work reassesses the trade-off between throughput and latency in long-context scenarios. We demonstrate that speculative decoding can enhance throughput, reduce latency, and maintain accuracy. Our theoretical and empirical analysis reveals that as the sequence length and batch size increase, the decoding bottleneck shifts from compute to memory. This shift makes speculative decoding effective for longer sequences, even with large batch sizes, achieving up to a 2x speedup for LLaMA-2-7B-32K and 1.84x for LLaMA-3.1-8B on 8 A100 GPUs. These results highlight the importance of integrating speculative decoding into throughput-optimization systems as long-context workloads become more prevalent.
</p>
</div>
<div class="has-text-centered">
<img src="static/images/icons/specdec.png" alt="Speculative decoding icon" width="200" height="200" />
</div>
</div>
</div>
</div>
</section>
<!-- Section: References -->
<section class="section" id="BibTeX">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<h2 class="title">BibTeX</h2>
<pre><code>@misc{chen2024magicdecbreakinglatencythroughputtradeoff,
title={MagicDec: Breaking the Latency-Throughput Tradeoff for Long Context Generation with Speculative Decoding},
author={Jian Chen and Vashisth Tiwari and Ranajoy Sadhukhan and Zhuoming Chen and Jinyuan Shi and Ian En-Hsu Yen and Beidi Chen},
year={2024},
eprint={2408.11049},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2408.11049},
}</code></pre>
</div>
</div>
</div>
</section>
<footer class="footer">
<div class="container is-fluid">
<div class="columns is-centered">
<div class="column is-four-fifths">
<div class="content">
<p>
This page was built using the <a href="https://github.com/eliahuhorwitz/Academic-project-page-template" target="_blank">Academic Project Page Template</a>, which was adapted from the <a href="https://nerfies.github.io" target="_blank">Nerfies</a> project page.
You are free to borrow the source code of this website; we just ask that you link back to this page in the footer. <br> This website is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/" target="_blank">Creative
Commons Attribution-ShareAlike 4.0 International License</a>. The icons were created by GPT-4.
</p>
</div>
</div>
</div>
</div>
</footer>