Skip to content

Commit

Permalink
Map-reduce support for computing the dot product
Browse files Browse the repository at this point in the history
  • Loading branch information
serengil committed Feb 15, 2025
1 parent 2309a3d commit d4fa5e8
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 22 deletions.
10 changes: 6 additions & 4 deletions lightphe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,13 @@ def __build_cryptosystem(
return cs

def encrypt(
self, plaintext: Union[int, float, list]
self, plaintext: Union[int, float, list], silent: bool = False
) -> Union[Ciphertext, EncryptedTensor]:
"""
Encrypt a plaintext with a built cryptosystem
Args:
plaintext (int, float or tensor): message
silent (bool): set this to True if you do not want to see progress bar
Returns
ciphertext (from lightphe.models.Ciphertext import Ciphertext): encrypted message
"""
Expand All @@ -169,7 +170,7 @@ def encrypt(

if isinstance(plaintext, list):
# then encrypt tensors
return self.__encrypt_tensors(tensor=plaintext)
return self.__encrypt_tensors(tensor=plaintext, silent=silent)

ciphertext = self.cs.encrypt(
plaintext=phe_utils.normalize_input(
Expand Down Expand Up @@ -206,11 +207,12 @@ def decrypt(

return self.cs.decrypt(ciphertext=ciphertext.value)

def __encrypt_tensors(self, tensor: list) -> EncryptedTensor:
def __encrypt_tensors(self, tensor: list, silent: bool = False) -> EncryptedTensor:
"""
Encrypt a given tensor
Args:
tensor (list of int or float)
silent (bool): set this to True if you do not want to see progress bar
Returns
encrypted tensor (list of encrypted tensor object)
"""
Expand Down Expand Up @@ -243,7 +245,7 @@ def __encrypt_tensors(self, tensor: list) -> EncryptedTensor:
for f in tqdm(
funclist,
desc="Encrypting tensors",
disable=True if len(tensor) < 100 else False,
disable=silent,
):
result = f.get(timeout=10)
encrypted_tensor.append(result)
Expand Down
98 changes: 86 additions & 12 deletions lightphe/models/Tensor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# built-in dependencies
from typing import Union, List
import multiprocessing
from contextlib import closing

# 3rd party dependencies
from tqdm import tqdm

# project dependencies
from lightphe.models.Homomorphic import Homomorphic
from lightphe.commons import phe_utils
from lightphe.models.Ciphertext import Ciphertext
Expand Down Expand Up @@ -90,21 +98,65 @@ def __matmul__(self, other: Union["EncryptedTensor", list]):

encrypted_tensor = self.__mul__(other=other)

sum_dividend = cast_ciphertext(
cs=self.cs, value=encrypted_tensor.fractions[0].abs_dividend
)
if len(encrypted_tensor.fractions) == 0:
raise ValueError("Dot product cannot be calculated for empty tensor")

divisor = cast_ciphertext(
cs=self.cs, value=encrypted_tensor.fractions[0].divisor
)
for fraction in encrypted_tensor.fractions[1:]:
sum_dividend += cast_ciphertext(value=fraction.abs_dividend, cs=self.cs)

fraction = Fraction(
dividend=sum_dividend.value,
abs_dividend=sum_dividend.value,
divisor=divisor.value,
sign=1,
)

if len(encrypted_tensor.fractions) > 10000:
# parallelize the sum operation
num_workers = min(
len(encrypted_tensor.fractions), multiprocessing.cpu_count()
)

chunks = chunkify(encrypted_tensor.fractions, num_workers)

with closing(multiprocessing.Pool(num_workers)) as pool:
funclist = []

for chunk in chunks:
f = pool.apply_async(sum_fractions_chunk, (chunk, self.cs))
funclist.append(f)

partial_sums = []
for f in tqdm(funclist, desc="Summing up fractions", disable=True):
result = f.get(timeout=10)
partial_sums.append(result)

# map reduce
total_sum = partial_sums[0]
for partial in partial_sums[1:]:
total_sum += partial

fraction = Fraction(
dividend=total_sum.value,
abs_dividend=total_sum.value,
divisor=divisor.value,
sign=1,
)
else:
# serial implementation
sum_dividend = cast_ciphertext(
cs=self.cs, value=encrypted_tensor.fractions[0].abs_dividend
)
divisor = cast_ciphertext(
cs=self.cs, value=encrypted_tensor.fractions[0].divisor
)

if len(encrypted_tensor.fractions) > 1:
for fraction in encrypted_tensor.fractions[1:]:
sum_dividend += cast_ciphertext(
value=fraction.abs_dividend, cs=self.cs
)

fraction = Fraction(
dividend=sum_dividend.value,
abs_dividend=sum_dividend.value,
divisor=divisor.value,
sign=1,
)

return EncryptedTensor(fractions=[fraction], cs=self.cs)

Expand Down Expand Up @@ -295,10 +347,32 @@ def __add__(self, other: "EncryptedTensor") -> "EncryptedTensor":


def cast_ciphertext(cs: Homomorphic, value: int) -> Ciphertext:
    """Wrap a raw integer ciphertext value into a Ciphertext object.

    Args:
        cs (Homomorphic): the cryptosystem the value belongs to; its class
            name and key material are copied onto the new Ciphertext.
        value (int): raw ciphertext value to wrap.
    Returns:
        Ciphertext: the wrapped value, carrying the cryptosystem's keys,
            form and curve (the latter two may be absent from the keys dict).
    """
    keys = cs.keys
    return Ciphertext(
        algorithm_name=type(cs).__name__,
        keys=keys,
        value=value,
        form=keys.get("form"),
        curve=keys.get("curve"),
    )


def chunkify(lst: list, n: int) -> list:
    """Split a list into n contiguous chunks of approximately equal size.

    The first ``len(lst) % n`` chunks receive one extra element, so chunk
    sizes differ by at most one and the original order is preserved.

    Args:
        lst (list): items to split.
        n (int): number of chunks; must be a positive integer.
    Returns:
        list: exactly n sub-lists covering lst in order; trailing chunks
            may be empty when n > len(lst).
    Raises:
        ValueError: if n is not positive (previously n == 0 surfaced as an
            opaque ZeroDivisionError and a negative n silently returned []).
    """
    if n <= 0:
        raise ValueError(f"number of chunks must be a positive integer but got {n}")
    # divmod gives the base chunk size and how many chunks get one extra item
    avg, remainder = divmod(len(lst), n)
    chunks = []
    start = 0
    for i in range(n):
        end = start + avg + (1 if i < remainder else 0)
        chunks.append(lst[start:end])
        start = end
    return chunks


def sum_fractions_chunk(fractions_chunk: list, cs: Homomorphic):
    """Homomorphically add up the abs_dividend values of a chunk of fractions.

    Worker function for the parallel (map-reduce) dot-product path: each
    process sums one chunk and the partial sums are combined by the caller.

    Args:
        fractions_chunk (list): non-empty list of fraction objects exposing
            an ``abs_dividend`` attribute.
        cs (Homomorphic): cryptosystem used to cast raw values to Ciphertext.
    Returns:
        Ciphertext: the homomorphic sum of the chunk's abs_dividend values.
    """
    chunk_iter = iter(fractions_chunk)
    total = cast_ciphertext(cs=cs, value=next(chunk_iter).abs_dividend)
    for frac in chunk_iter:
        total = total + cast_ciphertext(cs=cs, value=frac.abs_dividend)
    return total
26 changes: 20 additions & 6 deletions tests/test_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ def test_for_integer_tensor():
encrypted_similarity = enc_a @ b
decrypted_similarity = cs.decrypt(encrypted_similarity)[0]

assert (
abs(decrypted_similarity - expected_similarity) < 0.1
), f"expected {expected_similarity} but got {decrypted_similarity}"
assert abs(decrypted_similarity - expected_similarity) < 0.1, (
f"expected {expected_similarity} but got {decrypted_similarity}."
f"Diff = {abs(expected_similarity - decrypted_similarity)}"
)

logger.info("✅ Integer tensor tests succeeded")

Expand All @@ -236,15 +237,27 @@ def test_real_world_embedding():

# suppose that source and target embeddings are normalized vectors

source_embedding = [float(format(random.uniform(1, 2), ".17f")) for _ in range(128)]
# 3682

source_embedding = [
float(format(random.uniform(1, 2), ".17f")) for _ in range(4096)
]

# Randomly choose 3682 indices to set to zero - similar to VGG-Face
zero_indices = random.sample(range(4096), 3682)
for idx in zero_indices:
source_embedding[idx] = 0.0

logger.info(f"🤖 source image's embedding found - {len(source_embedding)}D")

tic = time.time()
source_embedding_encrypted = cs.encrypt(source_embedding)
toc = time.time()
logger.info(f"👨‍🔬 source embedding encrypted in {toc-tic} seconds")

target_embedding = [float(format(random.uniform(1, 2), ".17f")) for _ in range(128)]
target_embedding = [
float(format(random.uniform(1, 2), ".17f")) for _ in range(4096)
]
logger.info(f"🤖 target image's embedding found - {len(target_embedding)}D")

# dot product to calculate encrypted similarity
Expand All @@ -263,7 +276,8 @@ def test_real_world_embedding():
expected_similarity = sum(x * y for x, y in zip(source_embedding, target_embedding))

logger.info(
f"ℹ️ expected similarity: {expected_similarity}, got {decrypted_similarity}"
f"ℹ️ expected similarity: {expected_similarity}, got {decrypted_similarity}."
f"Difference: {abs(expected_similarity - decrypted_similarity)}."
)

assert (
Expand Down

0 comments on commit d4fa5e8

Please sign in to comment.