observation.py
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from transformers import AutoTokenizer

from rl4lms.data_pools.text_generation_pool import Sample


@dataclass
class Observation:
    # encoded input
    prompt_or_input_encoded_pt: torch.Tensor
    # attention mask for the input
    prompt_or_input_attention_mask_pt: torch.Tensor
    # input text
    prompt_or_input_text: str
    # encoded context
    context_encoded_pt: torch.Tensor
    # attention mask for the context
    context_attention_mask_pt: torch.Tensor
    # context text
    context_text: str
    # reference texts
    target_or_reference_texts: List[str]
    # concatenated input (prompt + context)
    input_encoded_pt: torch.Tensor
    # attention mask for the concatenated input
    input_attention_mask_pt: torch.Tensor
    # history of actions taken so far (as tokens)
    action_history: List[str]
    # other meta info
    meta_info: Dict[str, Any]

    def to_dict(self) -> Dict[str, np.ndarray]:
        """
        For stable-baselines-style Dict observation spaces: returns only the
        tensor fields, flattened into numpy arrays
        """
        dict_obs = {
            "prompt_or_input_encoded_pt": self.prompt_or_input_encoded_pt.numpy().flatten(),
            "prompt_or_input_attention_mask_pt": self.prompt_or_input_attention_mask_pt.numpy().flatten(),
            "context_encoded_pt": self.context_encoded_pt.numpy().flatten(),
            "context_attention_mask_pt": self.context_attention_mask_pt.numpy().flatten(),
            "input_encoded_pt": self.input_encoded_pt.numpy().flatten(),
            "input_attention_mask_pt": self.input_attention_mask_pt.numpy().flatten(),
        }
        return dict_obs
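
    # note: despite the "_pt" suffix in the keys, the values above are flat
    # numpy arrays; e.g. with max_input_length=24 and max_context_length=24
    # (as in the __main__ demo below), the prompt/context entries have 24
    # elements each and the concatenated input entries have 48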

    @staticmethod
    def _concat(prompt: torch.Tensor, prompt_mask: torch.Tensor,
                context: torch.Tensor, context_mask: torch.Tensor,
                pad_token: int):
        # keep only the non-padding tokens of the prompt and the context
        prompt_ = prompt[:, prompt_mask.flatten().bool().tolist()]
        context_ = context[:, context_mask.flatten().bool().tolist()]
        actual_size = prompt_.shape[1] + context_.shape[1]
        full_size = prompt.shape[1] + context.shape[1]

        # left-pad the concatenation of (prompt, context) back to the full size
        concatenated = torch.full(
            (full_size,), fill_value=pad_token).reshape(1, -1)
        concatenated_mask = torch.zeros((1, full_size)).int()
        concatenated[:, full_size - actual_size:] = torch.cat(
            (prompt_, context_), dim=1)
        concatenated_mask[:, full_size - actual_size:] = 1
        return concatenated, concatenated_mask
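
    # e.g. (illustrative, with pad_token = 0):
    #   prompt  = [[0, 5, 6]], prompt_mask  = [[0, 1, 1]]
    #   context = [[0, 0, 7]], context_mask = [[0, 0, 1]]
    #   gives   concatenated      = [[0, 0, 0, 5, 6, 7]]
    #   and     concatenated_mask = [[0, 0, 0, 1, 1, 1]]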

    def update(self, action: int, tokenizer: AutoTokenizer) -> "Observation":
        """
        Updates the observation using the given action
        """
        # update the action history
        current_action_history = deepcopy(self.action_history)
        current_action_history.append(tokenizer._convert_id_to_token(action))

        # get the current context
        current_context = deepcopy(self.context_encoded_pt)
        current_context_attention_mask = deepcopy(
            self.context_attention_mask_pt)

        # shift the context (and its attention mask) left by 1
        current_context[:, 0:-1] = current_context[:, 1:].clone()
        current_context_attention_mask[:, 0:-1] = \
            current_context_attention_mask[:, 1:].clone()

        # always add the action at the end (assumes left padding)
        current_context[:, -1] = action
        current_context_attention_mask[:, -1] = 1
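
        # e.g. (illustrative): a context [[PAD, PAD, c1]] with mask [[0, 0, 1]]
        # has now become [[PAD, c1, action]] with mask [[0, 1, 1]]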

        # decode the context
        context_text = tokenizer.decode(
            current_context.flatten(), skip_special_tokens=True)

        # concatenate, still keeping the left padding
        input_encoded_pt, input_attention_mask_pt = Observation._concat(
            self.prompt_or_input_encoded_pt, self.prompt_or_input_attention_mask_pt,
            current_context, current_context_attention_mask,
            tokenizer.pad_token_id)

        # and create a new observation
        obs = Observation(self.prompt_or_input_encoded_pt,
                          self.prompt_or_input_attention_mask_pt,
                          self.prompt_or_input_text,
                          current_context,
                          current_context_attention_mask,
                          context_text,
                          self.target_or_reference_texts,
                          input_encoded_pt,
                          input_attention_mask_pt,
                          current_action_history,
                          self.meta_info)
        return obs

    @classmethod
    def init_from_sample(cls, sample: Sample,
                         tokenizer: AutoTokenizer,
                         max_input_length: int,
                         max_context_length: int,
                         prompt_truncation_side: str,
                         context_start_token: Optional[int] = None,
                         meta_info: Optional[Dict[str, Any]] = None):
        # encode the prompt text, overriding the tokenizer's truncation side
        # for the prompt only
        prev_truncation_side = tokenizer.truncation_side
        tokenizer.truncation_side = prompt_truncation_side
        prompt_outputs = tokenizer(sample.prompt_or_input_text,
                                   padding="max_length",
                                   max_length=max_input_length,
                                   return_tensors="pt",
                                   return_attention_mask=True,
                                   truncation=True)
        tokenizer.truncation_side = prev_truncation_side

        # encode the (initially empty) context text
        context_outputs = tokenizer("",
                                    padding="max_length",
                                    max_length=max_context_length,
                                    return_tensors="pt",
                                    return_attention_mask=True)

        # for seq2seq models, the context should be initialized to the start
        # token if one is provided
        if context_start_token is not None:
            context_outputs.input_ids[:, -1] = context_start_token
            context_outputs.attention_mask[:, -1] = 1

        # concatenate
        input_encoded_pt, input_attention_mask_pt = Observation._concat(
            prompt_outputs.input_ids, prompt_outputs.attention_mask,
            context_outputs.input_ids, context_outputs.attention_mask,
            tokenizer.pad_token_id)

        obs = Observation(prompt_or_input_encoded_pt=prompt_outputs.input_ids,
                          prompt_or_input_attention_mask_pt=prompt_outputs.attention_mask,
                          prompt_or_input_text=sample.prompt_or_input_text,
                          context_encoded_pt=context_outputs.input_ids,
                          context_attention_mask_pt=context_outputs.attention_mask,
                          input_encoded_pt=input_encoded_pt,
                          input_attention_mask_pt=input_attention_mask_pt,
                          context_text="",
                          target_or_reference_texts=sample.references,
                          action_history=[],
                          meta_info=meta_info)
        return obs


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    sample = Sample("1", "Hello, this is cool", ["it is good", "going well"])
    obs = Observation.init_from_sample(
        sample=sample,
        tokenizer=tokenizer,
        max_input_length=24,
        max_context_length=24,
        prompt_truncation_side="left"  # required arg; "left" matches the left-padding setup
    )
    updated_obs = obs.update(10, tokenizer)
    updated_obs = updated_obs.update(11, tokenizer)
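
    # a minimal sanity check of the result (10 and 11 above are arbitrary
    # token ids, so the decoded context text depends on the gpt2 vocabulary)
    print(updated_obs.input_encoded_pt.shape)  # torch.Size([1, 48]): 24 + 24
    print(updated_obs.context_text)
    print(updated_obs.action_history)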