🎉 integrate GPT-LM with sender/receiver. working example!
falcondai committed Feb 28, 2019
1 parent 5815a9f commit fdbdfb6
Showing 3 changed files with 70 additions and 37 deletions.
51 changes: 38 additions & 13 deletions core.py
@@ -1,7 +1,6 @@
import numpy as np

from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree
from lm import LanguageModel


class LmAdversary:
@@ -24,14 +23,15 @@ class Sender:
plain_text - `cipher` -> cipher_text - `hide` -> stego_text
'''

def __init__(self, lm, cipher_key, cipher, cipher_text_length, tv_threshold, seed=None):
def __init__(self, lm, cipher_key, cipher, cipher_text_length, tv_threshold, max_sequence_length, seed=None):
self.cipher_key = cipher_key
self.cipher = cipher
self.lm = lm
self.tv_threshold = tv_threshold
self.cipher_text_length = cipher_text_length
self.random = np.random.RandomState(seed)
self.acc_risk = 0
self.max_sequence_length = max_sequence_length

def encrypt(self, plain_text):
cipher_text = self.cipher(self.cipher_key, plain_text)
@@ -57,15 +57,16 @@ def embed_bits(self, coin_flips):
prefix = [ind]
p = self.lm.p_next_token(prefix)
# Terminate the generation after we generate the EOS token
while len(prefix) == 1 or ind != self.lm.EOS_ind:
while len(prefix) == 1 or (len(prefix) < self.max_sequence_length and ind != self.lm.EOS_ind):
# There is still some cipher text to hide
le = len(coin_flips)
if le > 0:
# Build Huffman codes for the conditional distribution
heap = build_min_heap(p)
hc = huffman_tree(heap)
# Check if the total variation is low enough
if tv_huffman(hc, p) < self.tv_threshold:
print(tv_huffman(hc, p))
if tv_huffman(hc, p)[0] < self.tv_threshold:
# Huffman-decode the cipher text into a token
# Consume the cipher text until a token is generated
decoder_state = hc
@@ -88,7 +89,7 @@ def embed_bits(self, coin_flips):
prefix.append(ind)
p = self.lm.p_next_token(prefix)
# Drop the EOS index
return prefix[:1]
return prefix[1:]


class Receiver:
@@ -129,21 +130,23 @@ def recover_bits(self, token_inds, remaining_bits):
cipher_text = []
# Terminate the generation after we have consumed all indices or
# have extracted all bits
while 0 < len(token_inds) or remaining_bits == 0:
while 0 < len(token_inds) and 0 < remaining_bits:
# Build Huffman codes for the conditional distribution
heap = build_min_heap(p)
hc = huffman_tree(heap)
# Check if the total variation is low enough
if tv_huffman(hc, p) < self.tv_threshold:
if tv_huffman(hc, p)[0] < self.tv_threshold:
# We have controlled this step. Some bits are hidden.
code = invert_code_tree(hc)
# Look up the Huffman code for the token.
ind = token_inds.pop(0)
# Convert the Huffman code into bits
# left => 0, right => 1
cipher_text_fragment = [0 if bit == 'l' else 1 for bit in code[ind]]
cipher_text += cipher_text_fragment
# Truncate possible trailing paddings
cipher_text += cipher_text_fragment[:remaining_bits]
remaining_bits -= len(cipher_text_fragment)
print(remaining_bits)
prefix += [ind]
p = self.lm.p_next_token(prefix)
else:
@@ -154,9 +157,31 @@ def recover_bits(self, token_inds, remaining_bits):


if __name__ == '__main__':
lm = LanguageModel(['<s>', '</s>', 'a', 'b'], 0, 1)
cipher_text_length = 128
tv_threshold = 0.5
from gptlm import GptLanguageModel
lm = GptLanguageModel()
cipher_text_length = 32
# tv_threshold = float('inf')
tv_threshold = 0.08

alice = Sender(lm, None, None, cipher_text_length, tv_threshold, seed=123)
stego_text = alice.hide([0, 1, 0, 1, 1, 1, 0, 0])
print(stego_text)
bob = Receiver(lm, None, None, cipher_text_length, tv_threshold)

# sent_bits = list(np.random.choice(2, cipher_text_length))
sent_bits = [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0]
# sent_bits = [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0]
print(sent_bits)
stego_inds = alice.embed_bits(list(sent_bits))
msg = lm.enc.decode(stego_inds)

print(msg)

token_inds = lm.enc.encode(msg)
recovered_bits = bob.recover_bits(token_inds, cipher_text_length)
print(recovered_bits)

# Check
print(recovered_bits == sent_bits)
# stego_text = alice.hide(bits)
# print(stego_text)
# for seq in stego_text:
# print(''.join(seq))
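For reference, here is a minimal self-contained sketch of the round trip that embed_bits/recover_bits implement: build a Huffman tree over the model's next-token distribution, consume cipher bits to walk the tree to a leaf (the token to emit), then invert the code to read the bits back off the chosen token. The toy distribution and helper names below are illustrative only, not the repo's huffman module, and the total-variation check that core.py applies before hiding bits is omitted.

# Minimal sketch (not the repo's huffman module): embed cipher bits by walking
# a Huffman tree built over a next-token distribution, then recover them by
# inverting the code. The toy distribution is illustrative only.
import heapq
import itertools

def build_huffman_tree(p):
    """p: dict token -> probability. Returns a leaf token or a (left, right) tuple."""
    counter = itertools.count()  # tie-breaker so heapq never compares subtrees
    heap = [(prob, next(counter), tok) for tok, prob in p.items()]
    heapq.heapify(heap)
    while len(heap) > 1:
        p1, _, left = heapq.heappop(heap)
        p2, _, right = heapq.heappop(heap)
        heapq.heappush(heap, (p1 + p2, next(counter), (left, right)))
    return heap[0][2]

def embed(bits, tree):
    """Consume bits to walk the tree to a leaf; returns (token, bits_consumed)."""
    node, used = tree, 0
    while isinstance(node, tuple):
        # Pad with 0s if the cipher text runs out before reaching a leaf.
        bit = bits[used] if used < len(bits) else 0
        node = node[bit]  # left => 0, right => 1
        used += 1
    return node, used

def invert(tree, prefix=()):
    """Map each leaf token to its bit string (the inverse code table)."""
    if not isinstance(tree, tuple):
        return {tree: list(prefix)}
    codes = {}
    codes.update(invert(tree[0], prefix + (0,)))
    codes.update(invert(tree[1], prefix + (1,)))
    return codes

p = {'the': 0.4, 'a': 0.3, 'dog': 0.2, 'cat': 0.1}   # toy p_next_token output
tree = build_huffman_tree(p)
token, used = embed([1, 0, 1, 1], tree)              # sender hides a few bits
recovered = invert(tree)[token][:used]               # receiver reads them back
print(token, recovered)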
54 changes: 31 additions & 23 deletions gptlm.py
@@ -1,18 +1,18 @@
#!/usr/bin/env python3

import fire
import json
import os
import numpy as np
import tensorflow as tf

import encoder
from lm import LanguageModel

# import model, sample, encoder
import importlib.util
spec_model = importlib.util.spec_from_file_location("module.model", "./external/gpt-2/src/model.py")
model = importlib.util.module_from_spec(spec_model)
spec_model.loader.exec_module(model)
import encoder
from core import LanguageModel


class GptLanguageModel(LanguageModel):
@@ -43,19 +43,28 @@ def __init__(self, model_name='117M', seed=None, nsamples=1, batch_size=None, le
ckpt = tf.train.latest_checkpoint(os.path.join(base_path, 'models', model_name))
saver.restore(self.sess, ckpt)

self.EOS_ind = self.SOS
self.SOS_ind = self.SOS
self.vocabulary_index = self.enc.encoder
self.vocabulary = self.enc.decoder
self.vocabulary_size = self.hparams.n_vocab

def p_next_token(self, prefix):
# raw_text = prefix
# if not raw_text:
# print('Prompt should not be empty!')
# raise ValueError("must have prefix tokens.")
context_tokens = prefix
print('prefix', context_tokens)
# print('prefix', context_tokens)
context_tk_reshape = np.asarray(context_tokens).reshape((self.batch_size, -1))

out = self.sess.run(self.lm_output, feed_dict={
self.context: context_tk_reshape})
p_next_tk = out['logits']
return p_next_tk[0, -1]
logits = out['logits'][0, -1]
max_logit = logits.max()
p = np.exp(logits - max_logit)
p /= p.sum()
return p

def perplexity(self, sentence):
sos_padding = np.array([self.SOS for i in range(self.batch_size)]).reshape((self.batch_size, -1))
@@ -73,35 +82,34 @@ def perplexity(self, sentence):


if __name__ == '__main__':
def entropy(logits):
max_logit = logits.max()
p = np.exp(logits - max_logit)
p = p / p.sum()
def entropy(p):
return -np.sum(p * np.log2(p))

# Example
lm = GptLanguageModel()
prefix = [lm.SOS]
logits = lm.p_next_token(prefix)
print(logits)
i = logits.argmax()
print(logits[i], lm.enc.decoder[i])
inds = logits.argsort()[-10:]
print([lm.enc.decoder[i] for i in inds[::-1]])
print(entropy(logits))

# High entropy for some prefixes
i_have_a_lot = lm.enc.encode('I have a lot')
# Low entropy for some prefixes
prefix = 'I have a lot'
i_have_a_lot = lm.enc.encode(prefix)
logits = lm.p_next_token(i_have_a_lot)
print(logits.shape)
print(prefix)
print(logits)
i = logits.argmax()
print(logits[i], lm.enc.decoder[i])
inds = logits.argsort()[-10:]
print([lm.enc.decoder[i] for i in inds[::-1]])
print(entropy(logits))

# Low entropy for other prefixes
the_capital_of_us_is = lm.enc.encode('The capital of USA is')
logits = lm.p_next_token(the_capital_of_us_is)
print(logits.shape)
# High entropy for some other prefixes
prefix = 'I like your'
i_like_your = lm.enc.encode(prefix)
logits = lm.p_next_token(i_like_your)
print(prefix)
print(logits)
i = logits.argmax()
print(logits[i], lm.enc.decoder[i])
inds = logits.argsort()[-10:]
print([lm.enc.decoder[i] for i in inds[::-1]])
print(entropy(logits))
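The p_next_token change above turns the final-position logits into a proper distribution with a max-subtracted softmax; below is a standalone sketch of that conversion and of the entropy the demo prints, using made-up logits rather than GPT-2 output.

# Standalone sketch of the logits -> probability conversion used in
# p_next_token, plus the entropy printed by the demo. Logits are made up.
import numpy as np

def softmax(logits):
    # Subtract the max logit before exponentiating for numerical stability.
    z = np.exp(logits - logits.max())
    return z / z.sum()

def entropy_bits(p):
    # Skip zero-probability entries to avoid log2(0).
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

logits = np.array([3.2, 1.1, -0.5, 0.0])  # toy final-position logits
p = softmax(logits)
print(p, p.sum())       # a valid distribution summing to 1
print(entropy_bits(p))  # entropy of the next-token distribution, in bits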
2 changes: 1 addition & 1 deletion lm.py
@@ -10,7 +10,7 @@ def __init__(self, vocabulary, SOS_ind, EOS_ind):
# Mapping from indices to tokens
self.vocabulary = vocabulary
# Inverse map from tokens to indices
self.vocabulary_index = {token: ind for ind, token in self.vocabulary}
self.vocabulary_index = {token: ind for ind, token in enumerate(self.vocabulary)}
self.vocabulary_size = len(self.vocabulary)

def p_next_token(self, prefix):
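The lm.py fix matters because iterating the vocabulary list directly yields bare tokens rather than (index, token) pairs, so the old comprehension either raises or builds a wrong map; a quick illustration of the corrected inverse index with a toy vocabulary:

# Corrected inverse vocabulary index (toy vocabulary for illustration).
vocabulary = ['<s>', '</s>', 'a', 'b']
vocabulary_index = {token: ind for ind, token in enumerate(vocabulary)}
print(vocabulary_index)  # {'<s>': 0, '</s>': 1, 'a': 2, 'b': 3}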
