# This file includes code which was modified from https://github.com/openai/gpt-2

import json
import regex as re
from functools import lru_cache
import boto3


@lru_cache()
def bytes_to_unicode():
    """Map every byte (0-255) to a printable unicode character, so BPE can
    operate on arbitrary byte sequences without whitespace or control bytes."""
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
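

# For illustration (these checks are not part of the original file): printable
# bytes map to themselves, while e.g. the space byte (32) is shifted past 255:
#   assert bytes_to_unicode()[ord("!")] == "!"
#   assert bytes_to_unicode()[ord(" ")] == chr(256 + 32)  # "Ġ"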


def get_pairs(word):
    """Return the set of adjacent symbol pairs in `word`, a tuple of symbols."""
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
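

# Example (not part of the original file):
#   get_pairs(("h", "e", "l", "l", "o"))
#   == {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o")}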


class Encoder:
    def __init__(self, encoder, bpe_merges, errors="replace"):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle utf-8 decoding errors
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Lower rank means the pair was merged earlier in BPE training,
        # so it takes priority when merging.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
        # Splits text into contractions, letter runs, digit runs,
        # punctuation runs, and whitespace.
        self.pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        )

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Find the highest-priority (lowest-rank) pair present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again; keep the rest as-is.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word
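
    # Toy walkthrough of the merge loop (hypothetical ranks, not part of the
    # original file): with bpe_ranks = {("h", "e"): 0, ("he", "l"): 1},
    # bpe("hell") proceeds:
    #   ("h", "e", "l", "l") -> ("he", "l", "l") -> ("hel", "l") -> "hel l"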

    def encode(self, text):
        """Split `text` with the tokenization regex, byte-encode each piece,
        apply BPE, and map the resulting sub-words to integer ids."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def decode(self, tokens):
        """Map integer ids back to sub-words, then bytes, then a utf-8 string."""
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text
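

# Round-trip property (a sanity check, not part of the original file): for an
# Encoder `enc` built from the GPT-2 vocabulary, encoding then decoding should
# reproduce the input text:
#   assert enc.decode(enc.encode("hello world")) == "hello world"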


def get_encoder():
    """Build an Encoder from the GPT-2 124M vocabulary and merge list on S3."""
    s3 = boto3.client("s3")
    encoder = json.load(
        s3.get_object(Bucket="cortex-examples", Key="gpt-2/124M/encoder.json")["Body"]
    )
    bpe_data = (
        s3.get_object(Bucket="cortex-examples", Key="gpt-2/124M/vocab.bpe")["Body"]
        .read()
        .decode("utf-8")
    )
    # Skip the version header (first line) and the trailing empty line.
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
    return Encoder(encoder=encoder, bpe_merges=bpe_merges)
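
# If the bucket allows anonymous reads (an assumption; AWS credentials may
# instead be required in the environment), the client could be created with
# unsigned requests:
#   from botocore import UNSIGNED
#   from botocore.config import Config
#   s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))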


# Loaded once at import time so every request reuses the same vocabulary.
encoder = get_encoder()


def pre_inference(sample, metadata):
    # Encode the request text into token ids, batched for the model.
    context = encoder.encode(sample["text"])
    return {"context": [context]}


def post_inference(prediction, metadata):
    # Decode the sampled token ids back into text.
    return encoder.decode(prediction["response"]["sample"])
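

# Minimal local sanity check (hypothetical; in deployment Cortex calls these
# hooks around the model server request):
#   payload = pre_inference({"text": "Machine learning is"}, None)
#   fake_prediction = {"response": {"sample": payload["context"][0]}}
#   assert post_inference(fake_prediction, None) == "Machine learning is"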