diff --git a/core.py b/core.py
index 3c79861..ea2a87f 100644
--- a/core.py
+++ b/core.py
@@ -1,7 +1,6 @@
 import numpy as np
 from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree
-from lm import LanguageModel
 
 
 class LmAdversary:
@@ -24,7 +23,7 @@ class Sender:
     plain_text - `cipher` -> cipher_text - `hide` -> stego_text
     '''
 
-    def __init__(self, lm, cipher_key, cipher, cipher_text_length, tv_threshold, seed=None):
+    def __init__(self, lm, cipher_key, cipher, cipher_text_length, tv_threshold, max_sequence_length, seed=None):
         self.cipher_key = cipher_key
         self.cipher = cipher
         self.lm = lm
@@ -32,6 +31,7 @@ def __init__(self, lm, cipher_key, cipher, cipher_text_length, tv_threshold, see
         self.cipher_text_length = cipher_text_length
         self.random = np.random.RandomState(seed)
         self.acc_risk = 0
+        self.max_sequence_length = max_sequence_length
 
     def encrypt(self, plain_text):
         cipher_text = self.cipher(self.cipher_key, plain_text)
@@ -57,7 +57,7 @@ def embed_bits(self, coin_flips):
         prefix = [ind]
         p = self.lm.p_next_token(prefix)
         # Terminate the generation after we generate the EOS token
-        while len(prefix) == 1 or ind != self.lm.EOS_ind:
+        while len(prefix) == 1 or (len(prefix) < self.max_sequence_length and ind != self.lm.EOS_ind):
             # There is still some cipher text to hide
             le = len(coin_flips)
             if le > 0:
@@ -65,7 +65,8 @@ def embed_bits(self, coin_flips):
                 heap = build_min_heap(p)
                 hc = huffman_tree(heap)
                 # Check if the total variation is low enough
-                if tv_huffman(hc, p) < self.tv_threshold:
+                print(tv_huffman(hc, p))
+                if tv_huffman(hc, p)[0] < self.tv_threshold:
                     # Huffman-decode the cipher text into a token
                     # Consume the cipher text until a token is generated
                     decoder_state = hc
@@ -88,7 +89,7 @@ def embed_bits(self, coin_flips):
             prefix.append(ind)
             p = self.lm.p_next_token(prefix)
         # Drop the EOS index
-        return prefix[:1]
+        return prefix[1:]
 
 
 class Receiver:
@@ -129,12 +130,12 @@ def recover_bits(self, token_inds, remaining_bits):
         cipher_text = []
         # Terminate the generation after we have consumed all indices or
         # have extracted all bits
-        while 0 < len(token_inds) or remaining_bits == 0:
+        while 0 < len(token_inds) and 0 < remaining_bits:
             # Build Huffman codes for the conditional distribution
             heap = build_min_heap(p)
             hc = huffman_tree(heap)
             # Check if the total variation is low enough
-            if tv_huffman(hc, p) < self.tv_threshold:
+            if tv_huffman(hc, p)[0] < self.tv_threshold:
                 # We have controlled this step. Some bits are hidden.
                 code = invert_code_tree(hc)
                 # Look up the Huffman code for the token.
@@ -142,8 +143,10 @@ def recover_bits(self, token_inds, remaining_bits):
                 # Convert the Huffman code into bits
                 # left => 0, right => 1
                 cipher_text_fragment = [0 if bit == 'l' else 1 for bit in code[ind]]
-                cipher_text += cipher_text_fragment
+                # Truncate possible trailing paddings
+                cipher_text += cipher_text_fragment[:remaining_bits]
                 remaining_bits -= len(cipher_text_fragment)
+                print(remaining_bits)
                 prefix += [ind]
                 p = self.lm.p_next_token(prefix)
             else:
@@ -154,9 +157,31 @@ def recover_bits(self, token_inds, remaining_bits):
 
 
 if __name__ == '__main__':
-    lm = LanguageModel(['', '', 'a', 'b'], 0, 1)
-    cipher_text_length = 128
-    tv_threshold = 0.5
+    from gptlm import GptLanguageModel
+    lm = GptLanguageModel()
+    cipher_text_length = 32
+    # tv_threshold = float('inf')
+    tv_threshold = 0.08
+
     alice = Sender(lm, None, None, cipher_text_length, tv_threshold, seed=123)
-    stego_text = alice.hide([0, 1, 0, 1, 1, 1, 0, 0])
-    print(stego_text)
+    bob = Receiver(lm, None, None, cipher_text_length, tv_threshold)
+
+    # sent_bits = list(np.random.choice(2, cipher_text_length))
+    sent_bits = [1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0]
+    # sent_bits = [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0]
+    print(sent_bits)
+    stego_inds = alice.embed_bits(list(sent_bits))
+    msg = lm.enc.decode(stego_inds)
+
+    print(msg)
+
+    token_inds = lm.enc.encode(msg)
+    recovered_bits = bob.recover_bits(token_inds, cipher_text_length)
+    print(recovered_bits)
+
+    # Check
+    print(recovered_bits == sent_bits)
+    # stego_text = alice.hide(bits)
+    # print(stego_text)
+    # for seq in stego_text:
+    #     print(''.join(seq))
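
The core.py changes make `embed_bits`/`recover_bits` a proper round trip: the sender Huffman-decodes cipher bits against the language model's next-token distribution whenever the Huffman code is close enough to that distribution, and the receiver replays the same distributions to turn the tokens back into bits. Both sides now read the distance as `tv_huffman(hc, p)[0]`, which suggests the helper returns a tuple whose first element is the total-variation distance. Note that `Sender.__init__` now takes a required `max_sequence_length`, which the `__main__` block above does not pass, so a caller has to supply it. A minimal round-trip sketch, assuming the GPT-2 checkpoint used by gptlm.py is available and picking an arbitrary `max_sequence_length` of 100:

    from gptlm import GptLanguageModel
    from core import Sender, Receiver

    lm = GptLanguageModel()
    cipher_text_length = 32
    tv_threshold = 0.08
    max_sequence_length = 100  # assumed cap on the stego text length, not taken from the diff

    alice = Sender(lm, None, None, cipher_text_length, tv_threshold,
                   max_sequence_length, seed=123)
    bob = Receiver(lm, None, None, cipher_text_length, tv_threshold)

    sent_bits = [1, 0, 1, 1, 0, 0, 1, 0] * 4        # 32 cipher-text bits
    stego_inds = alice.embed_bits(list(sent_bits))  # token indices of the stego text
    msg = lm.enc.decode(stego_inds)                 # readable cover text

    recovered_bits = bob.recover_bits(lm.enc.encode(msg), cipher_text_length)
    print(recovered_bits == sent_bits)

As in the diff's own `__main__`, the comparison is printed rather than asserted, since re-tokenizing the decoded text is not guaranteed to reproduce the generated token indices exactly.
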
diff --git a/gptlm.py b/gptlm.py
index ff94938..0d057bf 100644
--- a/gptlm.py
+++ b/gptlm.py
@@ -1,18 +1,18 @@
 #!/usr/bin/env python3
-import fire
 import json
 import os
 
 import numpy as np
 import tensorflow as tf
 
+import encoder
+from lm import LanguageModel
+
 # import model, sample, encoder
 import importlib.util
 spec_model = importlib.util.spec_from_file_location("module.model", "./external/gpt-2/src/model.py")
 model = importlib.util.module_from_spec(spec_model)
 spec_model.loader.exec_module(model)
-import encoder
-from core import LanguageModel
 
 
 class GptLanguageModel(LanguageModel):
@@ -43,19 +43,28 @@ def __init__(self, model_name='117M', seed=None, nsamples=1, batch_size=None, le
         ckpt = tf.train.latest_checkpoint(os.path.join(base_path, 'models', model_name))
         saver.restore(self.sess, ckpt)
 
+        self.EOS_ind = self.SOS
+        self.SOS_ind = self.SOS
+        self.vocabulary_index = self.enc.encoder
+        self.vocabulary = self.enc.decoder
+        self.vocabulary_size = self.hparams.n_vocab
+
     def p_next_token(self, prefix):
         # raw_text = prefix
         # if not raw_text:
         #     print('Prompt should not be empty!')
         #     raise ValueError("must have prefix tokens.")
         context_tokens = prefix
-        print('prefix', context_tokens)
+        # print('prefix', context_tokens)
         context_tk_reshape = np.asarray(context_tokens).reshape((self.batch_size, -1))
         out = self.sess.run(self.lm_output, feed_dict={
             self.context: context_tk_reshape})
-        p_next_tk = out['logits']
-        return p_next_tk[0, -1]
+        logits = out['logits'][0, -1]
+        max_logit = logits.max()
+        p = np.exp(logits - max_logit)
+        p /= p.sum()
+        return p
 
     def perplexity(self, sentence):
         sos_padding = np.array([self.SOS for i in range(self.batch_size)]).reshape((self.batch_size, -1))
@@ -73,10 +82,7 @@ def perplexity(self, sentence):
 
 
 if __name__ == '__main__':
-    def entropy(logits):
-        max_logit = logits.max()
-        p = np.exp(logits - max_logit)
-        p = p / p.sum()
+    def entropy(p):
         return -np.sum(p * np.log2(p))
 
     # Example
@@ -84,24 +90,26 @@ def entropy(logits):
     prefix = [lm.SOS]
     logits = lm.p_next_token(prefix)
     print(logits)
-    i = logits.argmax()
-    print(logits[i], lm.enc.decoder[i])
+    inds = logits.argsort()[-10:]
+    print([lm.enc.decoder[i] for i in inds[::-1]])
     print(entropy(logits))
 
-    # High entropy for some prefixes
-    i_have_a_lot = lm.enc.encode('I have a lot')
+    # Low entropy for some prefixes
+    prefix = 'I have a lot'
+    i_have_a_lot = lm.enc.encode(prefix)
     logits = lm.p_next_token(i_have_a_lot)
-    print(logits.shape)
+    print(prefix)
     print(logits)
-    i = logits.argmax()
-    print(logits[i], lm.enc.decoder[i])
+    inds = logits.argsort()[-10:]
+    print([lm.enc.decoder[i] for i in inds[::-1]])
     print(entropy(logits))
 
-    # Low entropy for other prefixes
-    the_capital_of_us_is = lm.enc.encode('The capital of USA is')
-    logits = lm.p_next_token(the_capital_of_us_is)
-    print(logits.shape)
+    # High entropy for some other prefixes
+    prefix = 'I like your'
+    i_like_your = lm.enc.encode(prefix)
+    logits = lm.p_next_token(i_like_your)
+    print(prefix)
     print(logits)
-    i = logits.argmax()
-    print(logits[i], lm.enc.decoder[i])
+    inds = logits.argsort()[-10:]
+    print([lm.enc.decoder[i] for i in inds[::-1]])
    print(entropy(logits))
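
`GptLanguageModel.p_next_token` now returns a normalized probability vector instead of raw logits, which is what `build_min_heap` and `tv_huffman` in core.py consume, and the `entropy` helper in `__main__` is simplified to take that vector directly. The conversion is the usual numerically stable softmax: subtracting the maximum logit before exponentiating leaves the distribution unchanged, since exp(l_i - m) / sum_j exp(l_j - m) = exp(l_i) / sum_j exp(l_j). A self-contained sketch of the same computation, with made-up logits:

    import numpy as np

    def softmax(logits):
        # Shift by the max logit so np.exp cannot overflow; the result is identical.
        z = np.exp(logits - logits.max())
        return z / z.sum()

    def entropy(p):
        # Entropy in bits of a strictly positive distribution; low entropy means
        # few plausible next tokens, so fewer cipher bits can be hidden per step.
        return -np.sum(p * np.log2(p))

    p = softmax(np.array([3.2, 1.1, -0.4, 0.0]))
    print(p.sum())     # 1.0 up to floating-point error
    print(entropy(p))
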
diff --git a/lm.py b/lm.py
index 6e974e1..b1d67e8 100644
--- a/lm.py
+++ b/lm.py
@@ -10,7 +10,7 @@ def __init__(self, vocabulary, SOS_ind, EOS_ind):
         # Mapping from indices to tokens
         self.vocabulary = vocabulary
         # Inverse map from tokens to indices
-        self.vocabulary_index = {token: ind for ind, token in self.vocabulary}
+        self.vocabulary_index = {token: ind for ind, token in enumerate(self.vocabulary)}
         self.vocabulary_size = len(self.vocabulary)
 
     def p_next_token(self, prefix):
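
The lm.py change fixes the inverse vocabulary map: the old comprehension iterated over the token list itself, so each token string was unpacked as if it were an (index, token) pair, which raises ValueError for most tokens and silently builds a wrong map for two-character ones. With `enumerate`, the comprehension yields the intended token-to-index mapping. A toy illustration on a hypothetical four-token vocabulary:

    vocabulary = ['<s>', '</s>', 'a', 'b']

    # Old: {token: ind for ind, token in vocabulary} tries to unpack each string.
    # New:
    vocabulary_index = {token: ind for ind, token in enumerate(vocabulary)}
    print(vocabulary_index)  # {'<s>': 0, '</s>': 1, 'a': 2, 'b': 3}
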