 from text_generation_server.generator import NeuronGenerator
 from text_generation_server.pb.generate_pb2 import Batch
 
+import torch
 
 def test_decode(neuron_model_config):
     """Verify that a decoding for a single request generates the expected output."""
     config_name = neuron_model_config["name"]
     neuron_model_path = neuron_model_config["neuron_model_path"]
     generator = NeuronGenerator.from_pretrained(neuron_model_path)
-    for do_sample in [True, False]:
+    for do_sample in [False]:
         mode = "sample" if do_sample else "greedy"
         print(f"{config_name}[{mode}]")
         _test_decode(config_name, generator, do_sample)
         generator.clear()
 
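+# Greedy next-token selection from raw logits (arg-max over the last position);
+# defined for reference, nothing else in this diff calls it.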
+def sample_greedy(logits):
+    next_logits = logits[:, -1]
+    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
+    return next_token_id
+
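+# Re-implements greedy decoding by driving the model directly through its
+# prepare_inputs_for_prefill / prepare_inputs_for_decode API, bypassing the
+# generator, so the result can be compared against NeuronGenerator's output.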
+def manual_greedy(generator: NeuronGenerator, input_text: str, max_new_tokens: int):
+    model = generator.model
+    tokenizer = generator.tokenizer
+    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
+    attention_mask = torch.ones_like(input_ids)
+    seq_ids = torch.tensor([3], dtype=torch.int64)
+    # All-ones sampling parameters select greedy decoding (assuming the
+    # [top_k, top_p, temperature] layout, this amounts to top_k=1).
+    sampling_params = torch.ones([1, 3], device=model.device)
+
+    # Prefill: process the whole prompt and obtain the first generated token.
+    model_inputs = model.prepare_inputs_for_prefill(
+        input_ids,
+        attention_mask=attention_mask,
+        seq_ids=seq_ids,
+        sampling_params=sampling_params,
+    )
+    next_token = model(**model_inputs)[0].expand(1, -1)
+    output_tokens = next_token.clone()
+
+    # Decode: feed back one token at a time, extending the attention mask.
+    for _ in range(max_new_tokens - 1):
+        attention_mask = torch.cat([attention_mask, torch.ones([1, 1], device=model.device, dtype=torch.int64)], dim=1)
+        model_inputs = model.prepare_inputs_for_decode(
+            next_token,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
+        )
+        next_token = model(**model_inputs)[0].expand(1, -1)
+        output_tokens = torch.cat([output_tokens, next_token], dim=1)
+
+    return torch.cat([input_ids, output_tokens], dim=1)
+
+
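+# Debugging variant: let the generator's own prefill produce the first token,
+# then continue decoding manually; breakpoints are left in for interactive
+# inspection.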
+def manual_greedy_dbg(generator: NeuronGenerator, input_text: str, max_new_tokens: int):
+    # Greedy reference path, so the request must not sample.
+    request = create_request(
+        id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False
+    )
+    max_length = generator.model.neuron_config.sequence_length
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
+    generations, next_batch = generator.prefill(batch)
+    next_token = generations[0].tokens.ids
+    model = generator.model
+
+    output_tokens = torch.tensor([next_token])
+    next_token = torch.tensor([next_token])
+    breakpoint()  # inspect the prefill output before manual decoding starts
+
+    # Rebuild the prompt attention mask (it should be ones(1, 17) here).
+    tokenizer = generator.tokenizer
+    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
+    attention_mask = torch.ones_like(input_ids)
+
+    seq_ids = torch.tensor([3], dtype=torch.int64)
+    sampling_params = torch.ones([1, 3], device=model.device)
+
+    for _ in range(max_new_tokens - 1):
+        attention_mask = torch.cat([attention_mask, torch.ones([1, 1], device=model.device, dtype=torch.int64)], dim=1)
+        model_inputs = model.prepare_inputs_for_decode(
+            next_token,
+            attention_mask=attention_mask,
+            seq_ids=seq_ids,
+            sampling_params=sampling_params,
+        )
+        next_token = model(**model_inputs)[0].expand(1, -1)
+        output_tokens = torch.cat([output_tokens, next_token], dim=1)
+    generator.clear()
+    return torch.cat([input_ids, output_tokens], dim=1)
 
 def _test_decode(config_name, generator, do_sample):
     input_text = (
         "It was a bright cold day in April, and the clocks were striking thirteen."
     )
     max_new_tokens = 20
+
+    # Alternative reference via model.generate, kept commented for comparison:
+    # model = generator.model
+    # input_ids = tokenizer(input_text, return_tensors="pt").input_ids
+    # greedy_output = model.generate(input_ids, max_new_tokens=max_new_tokens)
+    # print("greedy_output", greedy_output)
+
+    # manual_greedy_output = manual_greedy(generator, input_text, max_new_tokens)
+    manual_greedy_output = manual_greedy_dbg(generator, input_text, max_new_tokens)
+    print("manual_greedy_output", manual_greedy_output)
+
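+    # The raw ids above are decoded and printed alongside the generator's
+    # output in the greedy branch below.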
     request = create_request(
         id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample
     )
     max_length = generator.model.neuron_config.sequence_length
     batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)
     generations, next_batch = generator.prefill(batch)
+
+    tokenizer = generator.tokenizer
+    tokens = generations[0].tokens
+    print(f"prefill tokens: {tokens.ids} {tokens.texts}")
     # We already generated one token: call decode max_new_tokens - 1 times
     for _ in range(max_new_tokens - 1):
         assert next_batch.size == 1
         assert next_batch.max_tokens == max_length
         assert len(generations) == 1
         assert len(generations[0].tokens.ids) == 1
         generations, next_batch = generator.decode([next_batch])
+        tokens = generations[0].tokens
+        print(f"decode tokens: {tokens.ids} {tokens.texts}")
     assert next_batch is None
     assert len(generations) == 1
     output = generations[0].generated_text
     assert output.generated_tokens == max_new_tokens
     assert output.finish_reason == 0
+
+    breakpoint()  # pause to inspect the final generation interactively
+
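+    # The expected strings below are per-model snapshots of the compiled
+    # Neuron model's output for this prompt.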
     if do_sample:
+        print(output.text)
         expected_text = {
-            "llama": " I sat alone in the café",
-            "qwen2": " The air was so still",
-            "granite": "1984, George Orwell",
+            "llama": " The world outside was grey and silent, except for the sound of people scurrying about, trying",
+            "qwen2": ' Old Mr.和Mr.和Mr.的姓氏是"布伦瑞特"。',
+            "granite": " Winston Smith, a low-ranking member of the ruling Party, works for the Min",
         }[config_name]
         assert expected_text in output.text
     else:
         print(output.text)
+        manual_greedy_text = tokenizer.decode(manual_greedy_output[0])
+        print("manual_greedy_output", manual_greedy_text)
         expected_text = {
             "llama": " The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility",
-            "qwen2": " I was sitting in my room, staring at the ceiling, when the door opened and in came a",
+            "qwen2": " I was sitting in my room, staring at the clock, when a knock at the door. I",
             "granite": "\n\nThis opening line from George Orwell's dystopian novel \"198",
         }[config_name]
         assert output.text == expected_text