-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fewshot.py
76 lines (54 loc) · 1.86 KB
/
fewshot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
import torch
# Select the compute device: Apple's MPS backend when PyTorch reports it as
# available, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load the DialogSum dataset from the Hugging Face Hub; its records are read
# via their 'dialogue' and 'summary' fields further down.
huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)
# NOTE(review): not referenced by the visible code; the few-shot run builds
# its own index list. Kept in case other code imports it.
example_indices = [50, 500]

# Horizontal rule (99 dashes) used to separate printed sections.
dash_line = '-' * 99

# FLAN-T5 base checkpoint and its matching fast tokenizer; the model is moved
# to `device` so that the tokenized inputs can be placed alongside it.
model_name = 'google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
def make_prompt(example_indices_full, example_index_to_summarize, split=None):
    """Build a few-shot dialogue-summarization prompt.

    Each index in ``example_indices_full`` contributes one solved example: the
    record's dialogue, the question "What was going on?", and the record's
    reference summary. The prompt then ends with the dialogue at
    ``example_index_to_summarize`` followed by the same question, leaving the
    answer for the model to generate.

    Args:
        example_indices_full: Indices of the records to include as solved
            examples, in order. An empty sequence yields a zero-shot prompt.
        example_index_to_summarize: Index of the record whose dialogue the
            model should summarize. Its summary is NOT included in the prompt.
        split: Indexable collection of records, each a mapping with
            ``'dialogue'`` and ``'summary'`` keys. Defaults to the module-level
            ``dataset['test']`` split, which preserves the original behavior.

    Returns:
        The assembled prompt string.
    """
    records = dataset['test'] if split is None else split

    # Collect the pieces and join once, instead of repeated string `+=`.
    parts = []
    for index in example_indices_full:
        record = records[index]  # fetch once rather than twice per field
        # The stop sequence '{summary}\n\n\n' is important for FLAN-T5, so each
        # solved example must actually end with it.
        # Other models may have their own preferred stop sequence.
        parts.append(
            f"\nDialogue:\n{record['dialogue']}\nWhat was going on?\n"
            f"{record['summary']}\n\n\n"
        )

    # The target dialogue gets the question but no answer.
    target = records[example_index_to_summarize]
    parts.append(f"\nDialogue:\n{target['dialogue']}\nWhat was going on?\n")
    return ''.join(parts)
# Few-shot configuration: the test-split records at `example_indices_full` are
# shown to the model with their reference summaries, and the model is then
# asked to summarize the record at `example_index_to_summarize`.
example_indices_full = [50, 100]
example_index_to_summarize = 500

few_shot_prompt = make_prompt(example_indices_full, example_index_to_summarize)
print(few_shot_prompt)

# Human-written reference summary of the target dialogue, printed as a baseline.
summary = dataset['test'][example_index_to_summarize]['summary']

# Pass the full tokenizer encoding (input_ids AND attention_mask) to generate().
# Passing input_ids alone drops the mask and forces generate() to infer it,
# which becomes unreliable once inputs are padded or batched.
inputs = tokenizer(few_shot_prompt, return_tensors='pt').to(device)
output = tokenizer.decode(
    model.generate(
        **inputs,
        max_new_tokens=50,
    )[0],  # first (and only) generated sequence
    skip_special_tokens=True
)

print(dash_line)
print(f'BASELINE HUMAN SUMMARY:\n{summary}\n')
print(dash_line)
print(f'MODEL GENERATION - FEW SHOT:\n{output}')