Skip to content

Commit

Permalink
add seek algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
falcondai committed Feb 27, 2019
1 parent f5f40f7 commit 5fe3a8e
Showing 1 changed file with 48 additions and 13 deletions.
61 changes: 48 additions & 13 deletions core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from huffman import build_min_heap, huffman_tree, tv_huffman
from huffman import build_min_heap, huffman_tree, tv_huffman, invert_code_tree
from lm import LanguageModel


Expand Down Expand Up @@ -40,32 +40,28 @@ def encrypt(self, plain_text):
def hide(self, cipher_text):
'''We use the cipher text to control the forward sampling procedure
in sampling tokens from the prefix language model.'''
assert len(cipher_text) == self.cipher_text_length, 'Ciphertext must be of length %s.' % self.cipher_text_length
stego_text = []
while len(cipher_text) > 0:
inds = self.control_sample(cipher_text)
inds = self.embed_bits(cipher_text)
# Look up the tokens
stego_text += [self.lm.vocabulary[ind] for ind in inds]
# XXX keep sequences separate
stego_text.append([self.lm.vocabulary[ind] for ind in inds])
return stego_text

def control_sample(self, coin_flips):
def embed_bits(self, coin_flips):
'''We use a sequence of coin flips to control the generation of token
indices from a language model. This returns _a sequence_ as defined by
the model, e.g. sentence, paragraph.'''
the language model, e.g. sentence, paragraph.'''
ind = self.lm.SOS_ind
prefix = [ind]
p = self.lm.p_next_token(prefix)
# Terminate the generation after we generate the EOS token
while ind != self.lm.EOS_ind:
while len(prefix) == 1 or ind != self.lm.EOS_ind:
# There is still some cipher text to hide
le = len(coin_flips)
if le > 0:
# Build Huffman codes for the conditional distribution
# if 2 ** le < self.lm.vocabulary_size:
# # Truncate the distribution to the le-most likely tokens
# inds = np.argsort(p)[-le:]
# heap = build_min_heap(p, inds)
# else:
# heap = build_min_heap(p)
heap = build_min_heap(p)
hc = huffman_tree(heap)
# Check if the total variation is low enough
Expand Down Expand Up @@ -115,7 +111,46 @@ def decrypt(self, cipher_text):

def seek(self, stego_text):
'''Seek the hidden cipher text from the given stego text by following
the same sampling procedure.'''
the same forward sampling procedure.'''
cipher_text = []
remaining_bits = self.cipher_text_length
for seq in stego_text:
inds = [self.lm.vocabulary_index[token] for token in seq]
cipher_text_fragment = self.recover_bits(inds, remaining_bits)
cipher_text += cipher_text_fragment
remaining_bits -= len(cipher_text_fragment)
assert len(cipher_text) == self.cipher_text_length, 'Ciphertext must be of length %s.' % self.cipher_text_length
return cipher_text

def recover_bits(self, token_inds, remaining_bits):
ind = self.lm.SOS_ind
prefix = [ind]
p = self.lm.p_next_token(prefix)
cipher_text = []
# Terminate the generation after we have consumed all indices or
# have extracted all bits
while 0 < len(token_inds) or remaining_bits == 0:
# Build Huffman codes for the conditional distribution
heap = build_min_heap(p)
hc = huffman_tree(heap)
# Check if the total variation is low enough
if tv_huffman(hc, p) < self.tv_threshold:
# We have controlled this step. Some bits are hidden.
code = invert_code_tree(hc)
# Look up the Huffman code for the token.
ind = token_inds.pop(0)
# Convert the Huffman code into bits
# left => 0, right => 1
cipher_text_fragment = [0 if bit == 'l' else 1 for bit in code[ind]]
cipher_text += cipher_text_fragment
remaining_bits -= len(cipher_text_fragment)
prefix += [ind]
p = self.lm.p_next_token(prefix)
else:
# We did not control this step. Skip.
prefix.append(token_inds.pop(0))
p = self.lm.p_next_token(prefix)
return cipher_text


if __name__ == '__main__':
Expand Down

0 comments on commit 5fe3a8e

Please sign in to comment.