Tokenizer special chars #197

Merged
merged 9 commits on Jul 26, 2023
117 changes: 101 additions & 16 deletions tokenizer.py
@@ -12,15 +12,17 @@ def __init__(self, tokenizer_model_path):
self.unk_token = "<unk>"
self.bos_token = "<s>"
self.eos_token = "</s>"
self.unk_token_id = self.tokenizer.unk_id()
self.unk_token_id = self.tokenizer.unk_id() # is the same as pad token id...
self.eos_token_id = self.tokenizer.eos_id()
self.bos_token_id = self.tokenizer.bos_id()
self.pad_token_id = 0 # self.tokenizer.pad_id()
self.newline_token_id = 13

self.special_characters = [(self.bos_token, self.bos_token_id), (self.eos_token, self.eos_token_id), (self.unk_token, self.unk_token_id)] # for tokenizer encoding

# Encode string

def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False, add_eos = False):
def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False, add_eos = False, encode_special_characters = False):

if isinstance(text, list):

@@ -61,15 +63,40 @@ def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False,
else:

# text is a single string

ids = self.tokenizer.EncodeAsIds(text)
split_text = [text]

# look for special characters
if encode_special_characters:
for special_character, special_token_id in self.special_characters:
temp_text = []
for segment in split_text:
if isinstance(segment, str) and special_character in segment:
# for each special character, append the text before the special character, then append the special character ID, then the rest of the text
parts = segment.split(special_character)
new_parts = []
for i, part in enumerate(parts):
new_parts.append(part)
if i < len(parts) - 1: # add the special token id between parts, but not after the last part
new_parts.append(special_token_id)
temp_text.extend(new_parts)
else:
temp_text.append(segment)
split_text = temp_text

ids = []

for text_chunk in split_text:
if isinstance(text_chunk, str):
ids += self.tokenizer.EncodeAsIds(text_chunk)
else:
ids.append(text_chunk)

# pad bos and eos

if add_bos:
ids = [self.bos_token_id] + ids
ids = [self.bos_token_id] + ids
if add_eos:
ids = ids + [self.eos_token_id]
ids = ids + [self.eos_token_id]

stacked_ids = torch.tensor(ids).unsqueeze(0)

@@ -78,25 +105,83 @@ def encode(self, text, return_mask = False, max_seq_len = 2048, add_bos = False,
else:
return stacked_ids
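
The new encode path, summarized outside the diff: the input string is split on each special-character literal, the literal is replaced in place by its token ID, and only the remaining plain-string pieces go through SentencePiece. A minimal standalone sketch of that technique, where sp stands for a loaded sentencepiece.SentencePieceProcessor and the example IDs are placeholders rather than values taken from this PR:

def encode_with_special(sp, text, special_characters):
    # special_characters: list of (literal, token_id) pairs, e.g. [("<s>", 1), ("</s>", 2)]
    segments = [text]
    for literal, token_id in special_characters:
        expanded = []
        for seg in segments:
            if isinstance(seg, str) and literal in seg:
                parts = seg.split(literal)
                for i, part in enumerate(parts):
                    expanded.append(part)
                    if i < len(parts) - 1:
                        expanded.append(token_id)  # the ID stands in for the literal
            else:
                expanded.append(seg)
        segments = expanded

    ids = []
    for seg in segments:
        # EncodeAsIds("") returns [], so empty fragments around the literals are harmless
        ids += sp.EncodeAsIds(seg) if isinstance(seg, str) else [seg]
    return ids

# e.g. encode_with_special(sp, "<s>Hello</s>", [("<s>", 1), ("</s>", 2)])
#      -> [1, <IDs for "Hello">, 2]
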

def decode(self, ids):
def decode(self, ids, decode_special_characters=False):

special_ids = {id_: char for char, id_ in self.special_characters} # create a lookup dictionary

if ids.dim() > 1:

texts = []
for i in range(ids.shape[0]):
seq = ids[i].tolist()
seq = [t for t in seq if t != self.pad_token_id]
if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
texts.append(self.tokenizer.Decode(seq))

if decode_special_characters:
text_parts = []
normal_ids = [] # list of lists
current_normal_ids = [] # current list of normal IDs
for idx, id_ in enumerate(seq):
if id_ in special_ids:
# Save the current list of normal IDs, then start a new one
normal_ids.append(current_normal_ids)
current_normal_ids = []
# Store special token as a string
text_parts.append(special_ids[id_])
else:
current_normal_ids.append(id_)
normal_ids.append(current_normal_ids) # save the last segment of normal IDs

decoded_segments = [self.tokenizer.Decode(segment) for segment in normal_ids]
for idx, decoded_segment in enumerate(decoded_segments):
text_parts.insert(2*idx, decoded_segment)

texts.append("".join(text_parts))
else:
if self.eos_token_id in seq: # to not mess up special char decoding
seq = seq[:seq.index(self.eos_token_id)]

return texts

else:

ids = ids.tolist()
text = self.tokenizer.Decode(ids)
return text

def num_tokens(self, text):
if decode_special_characters:

text_parts = []
normal_ids = [] # list of lists
current_normal_ids = [] # current list of normal IDs
for idx, id_ in enumerate(ids):
if id_ in special_ids:
# Save the current list of normal IDs, then start a new one
normal_ids.append(current_normal_ids)
current_normal_ids = []
# Store special token as a string
text_parts.append(special_ids[id_])
else:
current_normal_ids.append(id_)
normal_ids.append(current_normal_ids) # save the last segment of normal IDs

decoded_segments = [self.tokenizer.Decode(segment) for segment in normal_ids]
for idx, decoded_segment in enumerate(decoded_segments):
text_parts.insert(2*idx, decoded_segment)

text = "".join(text_parts)

else:

text = self.tokenizer.Decode(ids)

return text
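
Because the text_parts.insert(2 * idx, ...) interleaving above is easy to misread, here is the same decoding technique as a standalone sketch (again, sp is assumed to be a loaded SentencePieceProcessor and the IDs in the example are illustrative only): the ID sequence is cut into runs of ordinary IDs separated by special IDs, each run is decoded on its own, and the decoded runs are woven back between the special-token strings.

def decode_with_special(sp, ids, special_ids):
    # special_ids: {token_id: literal}, e.g. {1: "<s>", 2: "</s>"}
    text_parts = []      # special-token strings, in order of appearance
    normal_runs = [[]]   # runs of ordinary IDs between special tokens
    for id_ in ids:
        if id_ in special_ids:
            text_parts.append(special_ids[id_])
            normal_runs.append([])  # start a new run after each special token
        else:
            normal_runs[-1].append(id_)

    decoded_runs = [sp.Decode(run) for run in normal_runs]
    # Weave the decoded runs back in: run0, special0, run1, special1, ..., runN
    for idx, run_text in enumerate(decoded_runs):
        text_parts.insert(2 * idx, run_text)
    return "".join(text_parts)

# e.g. decode_with_special(sp, [1, 15043, 2], {1: "<s>", 2: "</s>"}) -> "<s>Hello</s>"
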

ids = self.tokenizer.Encode(text)
return len(ids)
def num_tokens(self, text, encode_special_characters = False):

if encode_special_characters:

ids = self.encode(text, encode_special_characters = True)
return ids.size(1)

else:

ids = self.tokenizer.Encode(text)
return len(ids)
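
Rough usage of the new flags, for anyone landing on this PR later. The class name and model path below are placeholders inferred from the __init__ shown above rather than anything this diff defines:

tokenizer = Tokenizer("models/llama-7b/tokenizer.model")  # placeholder class name and path

prompt = "<s>Hello</s>"

ids = tokenizer.encode(prompt, encode_special_characters = True)
# "<s>" and "</s>" are mapped to the bos/eos IDs instead of being tokenized as literal text.

texts = tokenizer.decode(ids, decode_special_characters = True)
# encode() adds a batch dimension, so this hits the batched branch and returns a list;
# texts[0] should round-trip back to "<s>Hello</s>".

n = tokenizer.num_tokens(prompt, encode_special_characters = True)
# Token count with the same special-character handling as encode().
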