Skip to content
17 changes: 17 additions & 0 deletions benchmark/reference_speculative/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
### Download MT-Bench data

```sh
wget -O question.jsonl https://raw.githubusercontent.com/lm-sys/FastChat/d04ce6453ae016d9e03626b679c07aa1388dcbee/fastchat/llm_judge/data/mt_bench/question.jsonl
```

### Benchmark Context QA

```sh
python3 bench_sglang.py --mode contextqa
```

### Benchmark MT-Bench

```sh
python3 bench_sglang.py --mode mtbench
```
119 changes: 119 additions & 0 deletions benchmark/reference_speculative/bench_sglang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import time
import datasets
import argparse
import json
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text


@sgl.function
def context_qa(s, refs, question):
    """Answer ``question`` from a list of reference passages in one chat turn.

    Args:
        s: The sglang program state that the messages are appended to.
        refs: Iterable of reference passage strings. Each one is listed in the
            prompt as ``Ref <i>: <text>``, with ``i`` starting at 0.
        question: The question text placed after the references.

    The model output is captured under the name ``"answer"``; generation is
    stopped at the string ``"END"``, which the prompt asks the model to emit.
    """
    s += sgl.system()
    # A distinct loop name is used on purpose: the previous version iterated
    # with `refs` as the loop variable, shadowing the parameter it was looping
    # over.
    ref_lines = "".join(f"Ref {i}: {ref}\n" for i, ref in enumerate(refs))
    prompt = (
        "Please answer a question according to the following references.\n"
        + ref_lines
        + "The question is: " + question + "\n"
        + "Please provide a single-paragraph answer. "
        + "Focus on the provided references and the answer to the question. "
        + 'End your answer paragraph with the word "END"\n'
    )
    s += sgl.user(prompt)
    s += sgl.assistant(sgl.gen("answer", stop="END"))


@sgl.function
def mt_bench(s, question_1, question_2):
    """Run a two-turn conversation, generating one answer per user question.

    The first reply is captured as ``"answer_1"`` and the second as
    ``"answer_2"``; both turns are appended to the same program state ``s``.
    """
    s += sgl.system()
    turns = ((question_1, "answer_1"), (question_2, "answer_2"))
    for user_text, answer_name in turns:
        s += sgl.user(user_text)
        s += sgl.assistant(sgl.gen(answer_name))


def main():
    """Run the reference-speculative benchmark and record the results.

    Command-line options defined here:
        --num:  maximum number of prompts to run (default 100).
        --mode: ``contextqa`` (prompts built from the ``miracl/hagrid`` dev
                split) or ``mtbench`` (two-turn prompts read from
                ``question.jsonl`` in the current directory).

    The function prints the batch latency and the summed ``completion_tokens``
    and ``decode_tokens`` reported for every generation, dumps the final states
    to ``tmp_output_<backend>.txt`` and appends one JSON line to the result
    file.

    Raises:
        ValueError: if ``args.mode`` is not one of the supported modes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num", type=int, default=100)
    parser.add_argument(
        "--mode", type=str, choices=["contextqa", "mtbench"], default="contextqa"
    )
    # args.parallel, args.backend and args.result_file are read below; they are
    # presumably registered by this helper -- confirm in sglang.test.test_utils.
    args = add_common_sglang_args_and_parse(parser)
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Each mode only decides which program to run, with which arguments, and
    # which generation names carry token statistics; the run itself is shared.
    if args.mode == "contextqa":
        dataset = datasets.load_dataset(
            "miracl/hagrid", split="dev", trust_remote_code=True
        )
        # Clamp the request count so a --num larger than the split does not
        # make select() fail on out-of-range indices.
        dataset = dataset.select(range(min(args.num, len(dataset))))
        arguments = [
            {"refs": [q["text"] for q in d["quotes"]], "question": d["query"]}
            for d in dataset
        ]
        program = context_qa
        answer_names = ("answer",)
    elif args.mode == "mtbench":
        arguments = []
        with open("question.jsonl", "r", encoding="utf-8") as lines:
            for line in lines:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                obj = json.loads(line)
                arguments.append(
                    {"question_1": obj["turns"][0], "question_2": obj["turns"][1]}
                )
        arguments = arguments[: args.num]
        program = mt_bench
        answer_names = ("answer_1", "answer_2")
    else:
        raise ValueError(f"Invalid mode: {args.mode}")

    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by wall-clock adjustments during the run.
    tic = time.perf_counter()
    states = program.run_batch(
        arguments,
        max_new_tokens=256,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    toc = time.perf_counter()

    completion_tokens = 0
    decode_tokens = 0
    for state in states:
        for name in answer_names:
            meta = state.get_meta_info(name)
            completion_tokens += meta["completion_tokens"]
            decode_tokens += meta["decode_tokens"]

    print(f"Latency: {toc - tic:.3f}")

    print(f"Completion tokens: {completion_tokens}")
    print(f"Decode tokens: {decode_tokens}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "reference_speculative",
            "backend": args.backend,
            "mode": args.mode,
            "latency": round(toc - tic, 3),
            # Report what was actually run: the input may hold fewer prompts
            # than --num asked for.
            "num_requests": len(arguments),
            "completion_tokens": completion_tokens,
            "decode_tokens": decode_tokens,
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
main()
102 changes: 102 additions & 0 deletions examples/usage/reference_speculative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import sglang as sgl

ref_text = """\
The location of Hogwarts is Scotland, UK.
The headmaster of Hogwarts is Albus Dumbledore.
The potions teacher in Hogwarts is Severus Snape.
The transfiguration teacher in Hogwarts is Minerva McGonagall.
The herbology teacher in Hogwarts is Pomona Sprout.
The defense against the dark arts teacher in Hogwarts is Gilderoy Lockhart."""

ref_code = """\
```
import numpy as np
import matplotlib.pyplot as plt

# Calculate the average
average_throughput = np.mean(tokens_per_sec_arr)
print(f"Average Throughput: {average_throughput} tokens/sec")

# Plotting the histogram
plt.hist(tokens_per_sec_arr, bins=20, color='blue', edgecolor='black', alpha=0.7)
plt.title('Histogram of Throughput Values')
plt.xlabel('Tokens per Second')
plt.ylabel('Frequency')
plt.axvline(average_throughput, color='red', linestyle='dashed', linewidth=1)
plt.text(average_throughput*0.9, max(plt.ylim())*0.9, f'Average: {average_throughput:.2f}', color = 'red')
plt.show()
```"""


@sgl.function
def simple_qa(s, question, answer, ref_text):
    """Ask ``question`` about ``ref_text`` and complete a pre-filled answer.

    The prompt wraps the reference text between begin/end marker lines, then
    appends the question and the ``answer`` prefix. The continuation is
    captured as ``"answer"`` and stops at the first newline.
    """
    prompt_pieces = (
        "According to the reference text, answer the following question:\n",
        "Reference text Begins.\n",
        ref_text + "\n",
        "Reference text Ends.\n",
        question + "\n",
    )
    for piece in prompt_pieces:
        s += piece
    # This call does not hand the reference text to the generator; the variant
    # that does would additionally pass ref_text=ref_text to sgl.gen.
    s += answer + sgl.gen("answer", stop="\n")


@sgl.function
def code_modification(s, ref_code, question):
    """Ask for a modified version of ``ref_code`` as described by ``question``.

    The prompt shows the code and the request, then opens a code fence so the
    reply starts inside it. The reply is captured as ``"code"``, with
    ``ref_code`` passed to the generator as ``ref_text``, and generation stops
    at the closing fence.
    """
    prompt_pieces = (
        "Question: Below is a python code:\n",
        ref_code + "\n",
        question + "\n",
        "Answer: Sure, here is the modified code:\n",
        "```",
    )
    for piece in prompt_pieces:
        s += piece
    s += sgl.gen("code", ref_text=ref_code, max_tokens=1024, temperature=0, stop="```")


def main():
    """Run both examples against a runtime endpoint at http://localhost:30000.

    First a batch of question/answer-prefix pairs is run through ``simple_qa``
    and each completed answer is printed; then ``code_modification`` is run in
    streaming mode and its text is printed chunk by chunk as it arrives.
    """
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    # (question, answer prefix) pairs; every request shares the same reference.
    qa_pairs = [
        ("Where is Hogwarts located?\n", "The location of Hogwarts"),
        ("Who is the headmaster of Hogwarts?\n", "The headmaster of Hogwarts"),
        (
            "Who is the potions teacher in Hogwarts?\n",
            "The potions teacher in Hogwarts",
        ),
        (
            "Who is the transfiguration teacher in Hogwarts?\n",
            "The transfiguration teacher in Hogwarts",
        ),
        (
            "Who is the herbology teacher in Hogwarts?\n",
            "The herbology teacher in Hogwarts",
        ),
        (
            "Who is the defense against the dark arts teacher in Hogwarts?\n",
            "The defense against the dark arts teacher in Hogwarts",
        ),
    ]
    batch = [
        {"ref_text": ref_text, "question": asked, "answer": prefix}
        for asked, prefix in qa_pairs
    ]
    for result in simple_qa.run_batch(batch, temperature=0):
        print(result["answer"])
    print("=" * 50)

    streamed = code_modification.run(
        ref_code=ref_code,
        question="Can you please change x axis to start from 0?",
        stream=True,
    )
    for piece in streamed.text_iter():
        print(piece, end="", flush=True)


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def gen(
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
regex: Optional[str] = None,
ref_text: Optional[str] = None,
):
if choices:
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
Expand All @@ -91,6 +92,7 @@ def gen(
ignore_eos,
dtype,
regex,
ref_text,
)


Expand Down
1 change: 1 addition & 0 deletions python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ def _resolve_sampling_params(self, sampling_params):
"ignore_eos",
"dtype",
"regex",
"ref_text",
]:
value = getattr(sampling_params, item, None)
if value is not None:
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/lang/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class SglSamplingParams:
dtype: Optional[str] = None
regex: Optional[str] = None

# for reference speculative decoding
ref_text: Optional[str] = None

def clone(self):
return SglSamplingParams(
self.max_new_tokens,
Expand Down Expand Up @@ -93,6 +96,7 @@ def to_srt_kwargs(self):
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
"regex": self.regex,
"ref_text": self.ref_text,
}


Expand Down Expand Up @@ -350,6 +354,7 @@ def __init__(
ignore_eos,
dtype,
regex,
ref_text,
):
super().__init__()
self.name = name
Expand All @@ -364,6 +369,7 @@ def __init__(
ignore_eos=ignore_eos,
dtype=dtype,
regex=regex,
ref_text=ref_text,
)

def __repr__(self):
Expand Down
15 changes: 5 additions & 10 deletions python/sglang/srt/layers/extend_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,20 +276,16 @@ def extend_attention_fwd(

def redundant_attention(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
max_len_in_batch,
):
total_token_num = k_buffer.shape[0]
B, H_Q, D = b_req_idx.shape[0], q_extend.shape[-2], q_extend.shape[-1]
B, H_Q, D = b_start_loc.shape[0], q_extend.shape[-2], q_extend.shape[-1]
q_buffer = torch.empty(
(total_token_num, H_Q, D), dtype=q_extend.dtype, device=q_extend.device
)
Expand Down Expand Up @@ -374,6 +370,9 @@ def test():
b_start_loc_extend = torch.zeros_like(b_seq_len)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()

########## TEST EXTEND ATTENTION ##########

extend_attention_fwd(
q_extend,
k_extend,
Expand All @@ -394,13 +393,9 @@ def test():

redundant_attention(
q_extend,
k_extend,
v_extend,
o_redundant,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_start_loc,
b_seq_len,
b_seq_len_prefix,
Expand All @@ -410,7 +405,7 @@ def test():
print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant)))
print("Max: ", torch.max(torch.abs(o_extend - o_redundant)))

assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
assert torch.allclose(o_extend, o_redundant, rtol=1e-2, atol=1e-3)


if __name__ == "__main__":
Expand Down
Loading