diff --git a/.github/workflows/publish_geth_testing.yml b/.github/workflows/publish_geth_testing.yml index b7cd41178d1b..b90ad79b0249 100644 --- a/.github/workflows/publish_geth_testing.yml +++ b/.github/workflows/publish_geth_testing.yml @@ -11,43 +11,40 @@ jobs: name: Build and test go-ethereum runs-on: ubuntu-latest steps: - - name: Check out tfhe-rs + - name: Checkout TFHE-rs uses: actions/checkout@v3 with: repository: zama-ai/tfhe-rs - ref: 0.2.1 + ref: 1d817c45d5234bcf33638406191b656998b30c2a path: tfhe-rs - - name: Check out tfhe-cli + - name: Checkout zbc-fhe-tool uses: actions/checkout@v3 with: - repository: tremblaythibaultl/tfhe-cli + repository: zama-ai/zbc-fhe-tool ref: main - path: tfhe-cli + token: ${{ secrets.CONCRETE_ACTIONS_TOKEN }} + path: zbc-fhe-tool - - name: Check out go-ethereum + - name: Checkout go-ethereum uses: actions/checkout@v3 with: path: go-ethereum - name: Build C API working-directory: ./tfhe-rs - run: make build_c_api + run: make build_c_api_experimental_deterministic_fft - name: Move library files run: | - mv ./tfhe-rs/target/release/tfhe.h ./go-ethereum/core/vm/ - sudo mv ./tfhe-rs/target/release/libtfhe.* /usr/lib/ + sudo mv ./tfhe-rs/target/release/tfhe.h /usr/include + sudo mv ./tfhe-rs/target/release/libtfhe.so /usr/lib/ - name: Generate TFHE-rs keys - working-directory: ./tfhe-cli - run: cargo run --release keygen bin . - - - name: Move keys + working-directory: ./zbc-fhe-tool run: | mkdir -p $HOME/.evmosd/zama/keys/network-fhe-keys - mv ./tfhe-cli/client_key.bin $HOME/.evmosd/zama/keys/network-fhe-keys/cks - mv ./tfhe-cli/server_key.bin $HOME/.evmosd/zama/keys/network-fhe-keys/sks + cargo run --features tfhe/x86_64-unix --release -- generate-keys -d $HOME/.evmosd/zama/keys/network-fhe-keys - name: Run tests working-directory: ./go-ethereum/core/vm diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 92e40e333166..8eaf1de0307d 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -1272,7 +1272,7 @@ func importCiphertext(accessibleState PrecompileAccessibleState, ct *tfheCiphert // Used when we want to skip FHE computation, e.g. gas estimation. func importRandomCiphertext(accessibleState PrecompileAccessibleState, t fheUintType) []byte { ct := new(tfheCiphertext) - ct.makeRandom(t) + ct.encrypt(*big.NewInt(0), t) importCiphertext(accessibleState, ct) ctHash := ct.getHash() return ctHash[:] @@ -1435,13 +1435,8 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller ctBytes := input[:len(input)-1] ctType := fheUintType(input[len(input)-1]) - // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState, ctType), nil - } - ct := new(tfheCiphertext) - err := ct.deserialize(ctBytes, ctType) + err := ct.deserializeCompact(ctBytes, ctType) if err != nil { logger.Error("verifyCiphertext failed to deserialize input ciphertext", "err", err, @@ -1451,9 +1446,11 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller } ctHash := ct.getHash() importCiphertext(accessibleState, ct) - logger.Info("verifyCiphertext success", - "ctHash", ctHash.Hex(), - "ctBytes64", hex.EncodeToString(ctBytes[:minInt(len(ctBytes), 64)])) + if accessibleState.Interpreter().evm.Commit { + logger.Info("verifyCiphertext success", + "ctHash", ctHash.Hex(), + "ctBytes64", hex.EncodeToString(ctBytes[:minInt(len(ctBytes), 64)])) + } return ctHash.Bytes(), nil } diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index 5ef24823947a..4be3d99c703f 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -437,8 +437,13 @@ func newTestState() *statefulPrecompileAccessibleState { } func verifyCiphertextInTestMemory(interpreter *EVMInterpreter, value uint64, depth int, t fheUintType) *tfheCiphertext { + // Simulate as if the ciphertext is compact and comes externally. + ser := encryptAndSerializeCompact(uint32(value), t) ct := new(tfheCiphertext) - ct.encrypt(*new(big.Int).SetUint64(value), t) + err := ct.deserializeCompact(ser, t) + if err != nil { + panic(err) + } return verifyTfheCiphertextInTestMemory(interpreter, ct, depth) } @@ -455,6 +460,68 @@ func toPrecompileInput(hashes ...common.Hash) []byte { return ret } +func VerifyCiphertext(t *testing.T, fheUintType fheUintType) { + var value uint32 + switch fheUintType { + case FheUint8: + value = 2 + case FheUint16: + value = 4283 + case FheUint32: + value = 1333337 + } + c := &verifyCiphertext{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + compact := encryptAndSerializeCompact(value, fheUintType) + input := append(compact, byte(fheUintType)) + out, err := c.Run(state, addr, addr, input, readOnly) + if err != nil { + t.Fatalf(err.Error()) + } + ct := new(tfheCiphertext) + if err = ct.deserializeCompact(compact, fheUintType); err != nil { + t.Fatalf(err.Error()) + } + if common.BytesToHash(out) != ct.getHash() { + t.Fatalf("output hash in verifyCipertext is incorrect") + } + res := getVerifiedCiphertextFromEVM(state.interpreter, ct.getHash()) + if res == nil { + t.Fatalf("verifyCiphertext must have verified given ciphertext") + } +} + +func VerifyCiphertextBadType(t *testing.T, actualType fheUintType, metadataType fheUintType) { + var value uint32 + switch actualType { + case FheUint8: + value = 2 + case FheUint16: + value = 4283 + case FheUint32: + value = 1333337 + } + c := &verifyCiphertext{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + compact := encryptAndSerializeCompact(value, actualType) + input := append(compact, byte(metadataType)) + _, err := c.Run(state, addr, addr, input, readOnly) + if err == nil { + t.Fatalf("verifyCiphertext must have failed on type mismatch") + } + if len(state.interpreter.verifiedCiphertexts) != 0 { + t.Fatalf("verifyCiphertext mustn't have verified given ciphertext") + } +} + func FheAdd(t *testing.T, fheUintType fheUintType) { var lhs, rhs uint64 switch fheUintType { @@ -673,6 +740,49 @@ func FheLt(t *testing.T, fheUintType fheUintType) { } } +func TestVerifyCiphertext8(t *testing.T) { + VerifyCiphertext(t, FheUint8) +} + +func TestVerifyCiphertext16(t *testing.T) { + VerifyCiphertext(t, FheUint16) +} + +func TestVerifyCiphertext32(t *testing.T) { + VerifyCiphertext(t, FheUint32) +} + +// func TestVerifyCiphertext8BadType(t *testing.T) { +// VerifyCiphertextBadType(t, FheUint8, FheUint16) +// VerifyCiphertextBadType(t, FheUint8, FheUint32) +// } + +// func TestVerifyCiphertext16BadType(t *testing.T) { +// VerifyCiphertextBadType(t, FheUint16, FheUint8) +// VerifyCiphertextBadType(t, FheUint16, FheUint32) +// } + +// func TestVerifyCiphertext32BadType(t *testing.T) { +// VerifyCiphertextBadType(t, FheUint32, FheUint8) +// VerifyCiphertextBadType(t, FheUint32, FheUint16) +// } + +func TestVerifyCiphertextBadCiphertext(t *testing.T) { + c := &verifyCiphertext{} + depth := 1 + state := newTestState() + state.interpreter.evm.depth = depth + addr := common.Address{} + readOnly := false + _, err := c.Run(state, addr, addr, make([]byte, 10), readOnly) + if err == nil { + t.Fatalf("verifyCiphertext must fail on bad ciphertext input") + } + if len(state.interpreter.verifiedCiphertexts) != 0 { + t.Fatalf("verifyCiphertext mustn't have verified given ciphertext") + } +} + func TestFheAdd8(t *testing.T) { FheAdd(t, FheUint8) } @@ -733,28 +843,6 @@ func TestFheLt32(t *testing.T) { FheLt(t, FheUint32) } -// func TestFheRand(t *testing.T) { -// c := &fheRand{} -// depth := 1 -// state := newTestState() -// state.interpreter.evm.depth = depth -// addr := common.Address{} -// readOnly := false - -// out, err := c.Run(state, addr, addr, nil, readOnly) -// if err != nil { -// t.Fatalf(err.Error()) -// } -// res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) -// if res == nil { -// t.Fatalf("output ciphertext is not found in verifiedCiphertexts") -// } -// decrypted := res.ciphertext.decrypt() -// if decrypted >= math.Pow(2, 3) { -// t.Fatalf("invalid decrypted result") -// } -// } - func TestUnknownCiphertextHandle(t *testing.T) { depth := 1 state := newTestState() diff --git a/core/vm/instructions.go b/core/vm/instructions.go index cd19bd0c30a1..ca89c21f799d 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -620,16 +620,18 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres if metadataInt.IsZero() { // If no metadata, it means this ciphertext itself hasn't been persisted to protected storage yet. We do that as part of SSTORE. metadata.refCount = 1 - metadata.length = uint64(fheCiphertextSize[verifiedCiphertext.ciphertext.fheUintType]) + metadata.length = uint64(expandedFheCiphertextSize[verifiedCiphertext.ciphertext.fheUintType]) metadata.fheUintType = verifiedCiphertext.ciphertext.fheUintType ciphertextSlot := newInt(val.Bytes()) ciphertextSlot.AddUint64(ciphertextSlot, 1) - logger.Info("opSstore persisting new ciphertext", - "protectedStorage", hex.EncodeToString(protectedStorage[:]), - "handle", hex.EncodeToString(val.Bytes()), - "type", metadata.fheUintType, - "len", metadata.length, - "ciphertextSlot", hex.EncodeToString(ciphertextSlot.Bytes())) + if interpreter.evm.Commit { + logger.Info("opSstore persisting new ciphertext", + "protectedStorage", hex.EncodeToString(protectedStorage[:]), + "handle", hex.EncodeToString(val.Bytes()), + "type", metadata.fheUintType, + "len", metadata.length, + "ciphertextSlot", hex.EncodeToString(ciphertextSlot.Bytes())) + } ctPart32 := make([]byte, 32) partIdx := 0 ctBytes := verifiedCiphertext.ciphertext.serialize() @@ -650,12 +652,14 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres // If metadata exists, bump the refcount by 1. metadata = *newCiphertextMetadata(interpreter.evm.StateDB.GetState(protectedStorage, val)) metadata.refCount++ - logger.Info("opSstore bumping refcount of existing ciphertext", - "protectedStorage", hex.EncodeToString(protectedStorage[:]), - "handle", hex.EncodeToString(val.Bytes()), - "type", metadata.fheUintType, - "len", metadata.length, - "refCount", metadata.refCount) + if interpreter.evm.Commit { + logger.Info("opSstore bumping refcount of existing ciphertext", + "protectedStorage", hex.EncodeToString(protectedStorage[:]), + "handle", hex.EncodeToString(val.Bytes()), + "type", metadata.fheUintType, + "len", metadata.length, + "refCount", metadata.refCount) + } } // Save the metadata in protected storage. interpreter.evm.StateDB.SetState(protectedStorage, val, metadata.serialize()) @@ -669,11 +673,13 @@ func garbageCollectProtectedStorage(metadataKey common.Hash, protectedStorage co logger := interpreter.evm.Logger metadata := newCiphertextMetadata(existingMetadataInt.Bytes32()) if metadata.refCount == 1 { - logger.Info("opSstore garbage-collecting ciphertext", - "protectedStorage", hex.EncodeToString(protectedStorage[:]), - "metadataKey", hex.EncodeToString(metadataKey[:]), - "type", metadata.fheUintType, - "len", metadata.length) + if interpreter.evm.Commit { + logger.Info("opSstore garbage-collecting ciphertext", + "protectedStorage", hex.EncodeToString(protectedStorage[:]), + "metadataKey", hex.EncodeToString(metadataKey[:]), + "type", metadata.fheUintType, + "len", metadata.length) + } // Zero the metadata key-value. interpreter.evm.StateDB.SetState(protectedStorage, metadataKey, zero) @@ -692,11 +698,13 @@ func garbageCollectProtectedStorage(metadataKey common.Hash, protectedStorage co slot.AddUint64(slot, 1) } } else if metadata.refCount > 1 { - logger.Info("opSstore decrementing ciphertext refCount", - "protectedStorage", hex.EncodeToString(protectedStorage[:]), - "metadataKey", hex.EncodeToString(metadataKey[:]), - "type", metadata.fheUintType, - "len", metadata.length) + if interpreter.evm.Commit { + logger.Info("opSstore decrementing ciphertext refCount", + "protectedStorage", hex.EncodeToString(protectedStorage[:]), + "metadataKey", hex.EncodeToString(metadataKey[:]), + "type", metadata.fheUintType, + "len", metadata.length) + } metadata.refCount-- interpreter.evm.StateDB.SetState(protectedStorage, existingMetadataHash, metadata.serialize()) } @@ -713,10 +721,7 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b newValHash := common.BytesToHash(newValBytes) oldValHash := interpreter.evm.StateDB.GetState(scope.Contract.Address(), common.Hash(loc.Bytes32())) protectedStorage := crypto.CreateProtectedStorageContractAddress(scope.Contract.Address()) - // Here, we assume that if the `Commit` flag is not set, no precompile would read/write an actual ciphertext - // from/to protected storage. Instead, all precompiles just insert random ciphertexts to memory. - // Therefore, if `Commit` is not set, we don't need to touch protected storage at all. - if interpreter.evm.Commit && newValHash != oldValHash { + if newValHash != oldValHash { // Since the old value is no longer stored in actual contract storage, run garbage collection on protected storage. garbageCollectProtectedStorage(oldValHash, protectedStorage, interpreter) // If a verified ciphertext, persist to protected storage. diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index 298c09c0ef55..56ac480bc614 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -713,7 +713,7 @@ func (c testCallerAddress) Address() common.Address { func newTestScopeConext() *ScopeContext { c := new(ScopeContext) c.Memory = NewMemory() - c.Memory.Resize(uint64(fheCiphertextSize[FheUint8]) * 3) + c.Memory.Resize(uint64(expandedFheCiphertextSize[FheUint8]) * 3) c.Stack = newstack() c.Contract = NewContract(testCallerAddress{}, testContractAddress{}, big.NewInt(10), 100000) return c @@ -732,7 +732,7 @@ func TestProtectedStorageSstoreSload(t *testing.T) { depth := 1 interpreter := newTestInterpreter() interpreter.evm.depth = depth - ct := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8) + ct := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint32) ctHash := ct.getHash() scope := newTestScopeConext() loc := uint256.NewInt(10) @@ -791,8 +791,8 @@ func TestProtectedStorageGarbageCollection(t *testing.T) { if metadata.refCount != 1 { t.Fatalf("metadata.refcount of ciphertext is not 1") } - if metadata.length != uint64(fheCiphertextSize[FheUint8]) { - t.Fatalf("metadata.length (%v) != ciphertext len (%v)", metadata.length, uint64(fheCiphertextSize[FheUint8])) + if metadata.length != uint64(expandedFheCiphertextSize[FheUint8]) { + t.Fatalf("metadata.length (%v) != ciphertext len (%v)", metadata.length, uint64(expandedFheCiphertextSize[FheUint8])) } ciphertextLocationsToCheck := (metadata.length + 32 - 1) / 32 startOfCiphertext := newInt(ctHash[:]) diff --git a/core/vm/tfhe.go b/core/vm/tfhe.go index e2099fdf6c6d..fc3f5db07ea5 100644 --- a/core/vm/tfhe.go +++ b/core/vm/tfhe.go @@ -17,10 +17,10 @@ package vm /* -#cgo CFLAGS: -O3 +#cgo CFLAGS: -O3 -I. #cgo LDFLAGS: -Llib -ltfhe -#include "tfhe.h" +#include #undef NDEBUG #include @@ -39,6 +39,13 @@ void* deserialize_client_key(BufferView in) { return cks; } +void* deserialize_compact_public_key(BufferView in) { + CompactPublicKey* pks = NULL; + const int r = compact_public_key_deserialize(in, &pks); + assert(r == 0); + return pks; +} + void checked_set_server_key(void *sks) { const int r = set_server_key(sks); assert(r == 0); @@ -58,6 +65,31 @@ void* deserialize_fhe_uint8(BufferView in) { return ct; } +void* deserialize_compact_fhe_uint8(BufferView in) { + CompactFheUint8List* list = NULL; + FheUint8* ct = NULL; + + int r = compact_fhe_uint8_list_deserialize(in, &list); + if(r != 0) { + return NULL; + } + size_t len = 0; + r = compact_fhe_uint8_list_len(list, &len); + // Expect only 1 ciphertext in the list. + if(r != 0 || len != 1) { + r = compact_fhe_uint8_list_destroy(list); + assert(r == 0); + return NULL; + } + r = compact_fhe_uint8_list_expand(list, &ct, 1); + if(r != 0) { + ct = NULL; + } + r = compact_fhe_uint8_list_destroy(list); + assert(r == 0); + return ct; +} + void serialize_fhe_uint16(void *ct, Buffer* out) { const int r = fhe_uint16_serialize(ct, out); assert(r == 0); @@ -72,6 +104,31 @@ void* deserialize_fhe_uint16(BufferView in) { return ct; } +void* deserialize_compact_fhe_uint16(BufferView in) { + CompactFheUint16List* list = NULL; + FheUint16* ct = NULL; + + int r = compact_fhe_uint16_list_deserialize(in, &list); + if(r != 0) { + return NULL; + } + size_t len = 0; + r = compact_fhe_uint16_list_len(list, &len); + // Expect only 1 ciphertext in the list. + if(r != 0 || len != 1) { + r = compact_fhe_uint16_list_destroy(list); + assert(r == 0); + return NULL; + } + r = compact_fhe_uint16_list_expand(list, &ct, 1); + if(r != 0) { + ct = NULL; + } + r = compact_fhe_uint16_list_destroy(list); + assert(r == 0); + return ct; +} + void serialize_fhe_uint32(void *ct, Buffer* out) { const int r = fhe_uint32_serialize(ct, out); assert(r == 0); @@ -86,6 +143,31 @@ void* deserialize_fhe_uint32(BufferView in) { return ct; } +void* deserialize_compact_fhe_uint32(BufferView in) { + CompactFheUint32List* list = NULL; + FheUint32* ct = NULL; + + int r = compact_fhe_uint32_list_deserialize(in, &list); + if(r != 0) { + return NULL; + } + size_t len = 0; + r = compact_fhe_uint32_list_len(list, &len); + // Expect only 1 ciphertext in the list. + if(r != 0 || len != 1) { + r = compact_fhe_uint32_list_destroy(list); + assert(r == 0); + return NULL; + } + r = compact_fhe_uint32_list_expand(list, &ct, 1); + if(r != 0) { + ct = NULL; + } + r = compact_fhe_uint32_list_destroy(list); + assert(r == 0); + return ct; +} + void destroy_fhe_uint8(void* ct) { fhe_uint8_destroy(ct); } @@ -287,129 +369,121 @@ uint32_t decrypt_fhe_uint32(void* cks, void* ct) return res; } -void client_key_encrypt_and_ser_fhe_uint8(void* cks, uint8_t value, Buffer* out) { +void* public_key_encrypt_fhe_uint8(void* pks, uint8_t value) { + CompactFheUint8List* list = NULL; FheUint8* ct = NULL; - const int encrypt_ok = fhe_uint8_try_encrypt_with_client_key_u8(value, cks, &ct); - assert(encrypt_ok == 0); + int r = compact_fhe_uint8_list_try_encrypt_with_compact_public_key_u8(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint8_list_expand(list, &ct, 1); + assert(r == 0); - const int ser_ok = fhe_uint8_serialize(ct, out); - assert(ser_ok == 0); + r = compact_fhe_uint8_list_destroy(list); + assert(r == 0); - fhe_uint8_destroy(ct); + return ct; } -void client_key_encrypt_and_ser_fhe_uint16(void* cks, uint16_t value, Buffer* out) { +void* public_key_encrypt_fhe_uint16(void* pks, uint16_t value) { + CompactFheUint16List* list = NULL; FheUint16* ct = NULL; - const int encrypt_ok = fhe_uint16_try_encrypt_with_client_key_u16(value, cks, &ct); - assert(encrypt_ok == 0); + int r = compact_fhe_uint16_list_try_encrypt_with_compact_public_key_u16(&value, 1, pks, &list); + assert(r == 0); - const int ser_ok = fhe_uint16_serialize(ct, out); - assert(ser_ok == 0); + r = compact_fhe_uint16_list_expand(list, &ct, 1); + assert(r == 0); - fhe_uint16_destroy(ct); + r = compact_fhe_uint16_list_destroy(list); + assert(r == 0); + + return ct; } -void client_key_encrypt_and_ser_fhe_uint32(void* cks, uint32_t value, Buffer* out) { +void* public_key_encrypt_fhe_uint32(void* pks, uint32_t value) { + CompactFheUint32List* list = NULL; FheUint32* ct = NULL; - const int encrypt_ok = fhe_uint32_try_encrypt_with_client_key_u32(value, cks, &ct); - assert(encrypt_ok == 0); + int r = compact_fhe_uint32_list_try_encrypt_with_compact_public_key_u32(&value, 1, pks, &list); + assert(r == 0); + + r = compact_fhe_uint32_list_expand(list, &ct, 1); + assert(r == 0); - const int ser_ok = fhe_uint32_serialize(ct, out); - assert(ser_ok == 0); + r = compact_fhe_uint32_list_destroy(list); + assert(r == 0); - fhe_uint32_destroy(ct); + return ct; } -void* client_key_encrypt_fhe_uint8(void* cks, uint8_t value) { +void* trivial_encrypt_fhe_uint8(void* sks, uint8_t value) { FheUint8* ct = NULL; - const int r = fhe_uint8_try_encrypt_with_client_key_u8(value, cks, &ct); + checked_set_server_key(sks); + + int r = fhe_uint8_try_encrypt_trivial_u8(value, &ct); assert(r == 0); return ct; } -void* client_key_encrypt_fhe_uint16(void* cks, uint16_t value) { +void* trivial_encrypt_fhe_uint16(void* sks, uint16_t value) { FheUint16* ct = NULL; - const int r = fhe_uint16_try_encrypt_with_client_key_u16(value, cks, &ct); + checked_set_server_key(sks); + + int r = fhe_uint16_try_encrypt_trivial_u16(value, &ct); assert(r == 0); return ct; } -void* client_key_encrypt_fhe_uint32(void* cks, uint32_t value) { +void* trivial_encrypt_fhe_uint32(void* sks, uint32_t value) { FheUint32* ct = NULL; - const int r = fhe_uint32_try_encrypt_with_client_key_u32(value, cks, &ct); + checked_set_server_key(sks); + + int r = fhe_uint32_try_encrypt_trivial_u32(value, &ct); assert(r == 0); return ct; } -void public_key_encrypt_fhe_uint8(BufferView pks_buf, uint8_t value, Buffer* out) -{ - FheUint8 *ct = NULL; - PublicKey *pks = NULL; - - const int deser_ok = public_key_deserialize(pks_buf, &pks); - assert(deser_ok == 0); +void public_key_encrypt_and_serialize_fhe_uint8_list(void* pks, uint8_t value, Buffer* out) { + CompactFheUint8List* list = NULL; - const int encrypt_ok = fhe_uint8_try_encrypt_with_public_key_u8(value, pks, &ct); - assert(encrypt_ok == 0); - - const int ser_ok = fhe_uint8_serialize(ct, out); - assert(ser_ok == 0); + int r = compact_fhe_uint8_list_try_encrypt_with_compact_public_key_u8(&value, 1, pks, &list); + assert(r == 0); - public_key_destroy(pks); - fhe_uint8_destroy(ct); + r = compact_fhe_uint8_list_serialize(list, out); + assert(r == 0); } -void public_key_encrypt_fhe_uint16(BufferView pks_buf, uint16_t value, Buffer* out) -{ - FheUint16 *ct = NULL; - PublicKey *pks = NULL; - - const int deser_ok = public_key_deserialize(pks_buf, &pks); - assert(deser_ok == 0); +void public_key_encrypt_and_serialize_fhe_uint16_list(void* pks, uint16_t value, Buffer* out) { + CompactFheUint16List* list = NULL; - const int encrypt_ok = fhe_uint16_try_encrypt_with_public_key_u16(value, pks, &ct); - assert(encrypt_ok == 0); - - const int ser_ok = fhe_uint16_serialize(ct, out); - assert(ser_ok == 0); + int r = compact_fhe_uint16_list_try_encrypt_with_compact_public_key_u16(&value, 1, pks, &list); + assert(r == 0); - public_key_destroy(pks); - fhe_uint16_destroy(ct); + r = compact_fhe_uint16_list_serialize(list, out); + assert(r == 0); } -void public_key_encrypt_fhe_uint32(BufferView pks_buf, uint32_t value, Buffer* out) -{ - FheUint32 *ct = NULL; - PublicKey *pks = NULL; - - const int deser_ok = public_key_deserialize(pks_buf, &pks); - assert(deser_ok == 0); +void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value, Buffer* out) { + CompactFheUint32List* list = NULL; - const int encrypt_ok = fhe_uint32_try_encrypt_with_public_key_u32(value, pks, &ct); - assert(encrypt_ok == 0); - - const int ser_ok = fhe_uint32_serialize(ct, out); - assert(ser_ok == 0); + int r = compact_fhe_uint32_list_try_encrypt_with_compact_public_key_u32(&value, 1, pks, &list); + assert(r == 0); - public_key_destroy(pks); - fhe_uint32_destroy(ct); + r = compact_fhe_uint32_list_serialize(list, out); + assert(r == 0); } + */ import "C" -// TODO trivial encrypt - import ( - "crypto/rand" "errors" "fmt" "math/big" @@ -438,11 +512,13 @@ func homeDir() string { return home } -// The TFHE ciphertext size, in bytes. -var fheCiphertextSize map[fheUintType]uint +// TFHE ciphertext sizes by type, in bytes. +// Note: These sizes are for expanded (non-compacted) ciphertexts. +var expandedFheCiphertextSize map[fheUintType]uint var sks unsafe.Pointer var cks unsafe.Pointer +var pks unsafe.Pointer var networkKeysDir string var usersKeysDir string @@ -460,40 +536,41 @@ func runGc() { } func init() { - fheCiphertextSize = make(map[fheUintType]uint) - - fheCiphertextSize[FheUint8] = 28124 - fheCiphertextSize[FheUint16] = 56236 - fheCiphertextSize[FheUint32] = 112460 + expandedFheCiphertextSize = make(map[fheUintType]uint) go runGc() - + home := homeDir() networkKeysDir = home + "/.evmosd/zama/keys/network-fhe-keys/" usersKeysDir = home + "/.evmosd/zama/keys/users-fhe-keys/" sksBytes, err := os.ReadFile(networkKeysDir + "sks") if err != nil { - fmt.Print("WARNING: file sks not found.\n") + fmt.Println("WARNING: file sks not found.") return } sks = C.deserialize_server_key(toBufferView(sksBytes)) + expandedFheCiphertextSize[FheUint8] = uint(len(new(tfheCiphertext).trivialEncrypt(*big.NewInt(0), FheUint8).serialize())) + expandedFheCiphertextSize[FheUint16] = uint(len(new(tfheCiphertext).trivialEncrypt(*big.NewInt(0), FheUint16).serialize())) + expandedFheCiphertextSize[FheUint32] = uint(len(new(tfheCiphertext).trivialEncrypt(*big.NewInt(0), FheUint32).serialize())) + cksBytes, err := os.ReadFile(networkKeysDir + "cks") if err != nil { - fmt.Print("WARNING: file cks not found.\n") + fmt.Println("WARNING: file cks not found.") return } cks = C.deserialize_client_key(toBufferView(cksBytes)) - // Cannot use trivial encryption yet as it is not exposed by tfhe-rs - // ct := new(tfheCiphertext) - // ct.trivialEncrypt(1) - // fheCiphertextSize = len(ct.serialize()) + pksBytes, err := os.ReadFile(networkKeysDir + "pks") + if err != nil { + fmt.Println("WARNING: file pks not found.") + return + } + pks = C.deserialize_compact_public_key(toBufferView(pksBytes)) } -// Represents a TFHE ciphertext type (i.e., its bit capacity) - +// Represents a TFHE ciphertext type, i.e. its bit capacity. type fheUintType uint8 const ( @@ -502,19 +579,19 @@ const ( FheUint32 fheUintType = 2 ) -// Represents a TFHE ciphertext. +// Represents an expanded TFHE ciphertext. // -// Once a ciphertext has a value (either from deserialization, encryption or makeRandom()), -// it must not be set another value. If that is needed, a new ciphertext must be created. +// Once a ciphertext has a value (i.e. from deserialization), it must not be set +// another value. If that is needed, a new ciphertext must be created. type tfheCiphertext struct { ptr unsafe.Pointer serialization []byte hash []byte value *big.Int - random bool fheUintType fheUintType } +// Deserializes a TFHE ciphertext. func (ct *tfheCiphertext) deserialize(in []byte, t fheUintType) error { if ct.initialized() { panic("cannot deserialize to an existing ciphertext") @@ -529,7 +606,7 @@ func (ct *tfheCiphertext) deserialize(in []byte, t fheUintType) error { ptr = C.deserialize_fhe_uint32(toBufferView((in))) } if ptr == nil { - return errors.New("tfhe ciphertext deserialization failed") + return errors.New("TFHE ciphertext deserialization failed") } ct.setPtr(ptr) ct.fheUintType = t @@ -537,41 +614,68 @@ func (ct *tfheCiphertext) deserialize(in []byte, t fheUintType) error { return nil } -func (ct *tfheCiphertext) encrypt(value big.Int, t fheUintType) { +// Deserializes a compact TFHE ciphetext. +// Note: After the compact thfe ciphertext has been serialized, subsequent calls to serialize() +// will produce non-compact ciphertext serialziations. +func (ct *tfheCiphertext) deserializeCompact(in []byte, t fheUintType) error { + if ct.initialized() { + panic("cannot deserialize to an existing ciphertext") + } + var ptr unsafe.Pointer + switch t { + case FheUint8: + ptr = C.deserialize_compact_fhe_uint8(toBufferView((in))) + case FheUint16: + ptr = C.deserialize_compact_fhe_uint16(toBufferView((in))) + case FheUint32: + ptr = C.deserialize_compact_fhe_uint32(toBufferView((in))) + } + if ptr == nil { + return errors.New("TFHE ciphertext deserialization failed") + } + ct.setPtr(ptr) + ct.fheUintType = t + return nil +} + +// Encrypts a value as a TFHE ciphertext, using the compact public FHE key. +// The resulting ciphertext is automaticaly expanded. +func (ct *tfheCiphertext) encrypt(value big.Int, t fheUintType) *tfheCiphertext { if ct.initialized() { panic("cannot encrypt to an existing ciphertext") } switch t { case FheUint8: - ct.setPtr(C.client_key_encrypt_fhe_uint8(cks, C.uchar(value.Uint64()))) + ct.setPtr(C.public_key_encrypt_fhe_uint8(pks, C.uint8_t(value.Uint64()))) case FheUint16: - ct.setPtr(C.client_key_encrypt_fhe_uint16(cks, C.ushort(value.Uint64()))) + ct.setPtr(C.public_key_encrypt_fhe_uint16(pks, C.uint16_t(value.Uint64()))) case FheUint32: - ct.setPtr(C.client_key_encrypt_fhe_uint32(cks, C.uint(value.Uint64()))) + ct.setPtr(C.public_key_encrypt_fhe_uint32(pks, C.uint32_t(value.Uint64()))) } ct.fheUintType = t ct.value = &value + return ct } -func (ct *tfheCiphertext) makeRandom(t fheUintType) { +func (ct *tfheCiphertext) trivialEncrypt(value big.Int, t fheUintType) *tfheCiphertext { if ct.initialized() { - panic("cannot make an existing ciphertext random") + panic("cannot encrypt to an existing ciphertext") + } + + switch t { + case FheUint8: + ct.setPtr(C.trivial_encrypt_fhe_uint8(sks, C.uint8_t(value.Uint64()))) + case FheUint16: + ct.setPtr(C.trivial_encrypt_fhe_uint16(sks, C.uint16_t(value.Uint64()))) + case FheUint32: + ct.setPtr(C.trivial_encrypt_fhe_uint32(sks, C.uint32_t(value.Uint64()))) } - ct.serialization = make([]byte, fheCiphertextSize[t]) - rand.Read(ct.serialization) ct.fheUintType = t - ct.random = true + ct.value = &value + return ct } -// func (ct *tfheCiphertext) trivialEncrypt(value uint64) { -// if ct.initialized() { -// panic("cannot trivially encrypt to an existing ciphertext") -// } -// ct.setPtr(C.trivial_encrypt(sks, C.ulong(value))) -// ct.value = &value -// } - func (ct *tfheCiphertext) serialize() []byte { if !ct.initialized() { panic("cannot serialize a non-initialized ciphertext") @@ -754,39 +858,26 @@ func (ct *tfheCiphertext) getHash() common.Hash { } func (ct *tfheCiphertext) availableForOps() bool { - return (ct.initialized() && ct.ptr != nil && !ct.random) + return (ct.initialized() && ct.ptr != nil) } func (ct *tfheCiphertext) initialized() bool { - return (ct.ptr != nil || ct.random) + return (ct.ptr != nil) } -func clientKeyEncrypt(value uint64, t fheUintType) []byte { +// Used for testing. +func encryptAndSerializeCompact(value uint32, fheUintType fheUintType) []byte { out := &C.Buffer{} - switch t { + switch fheUintType { case FheUint8: - C.client_key_encrypt_and_ser_fhe_uint8(cks, C.uchar(value), out) + C.public_key_encrypt_and_serialize_fhe_uint8_list(pks, C.uint8_t(value), out) case FheUint16: - C.client_key_encrypt_and_ser_fhe_uint16(cks, C.ushort(value), out) + C.public_key_encrypt_and_serialize_fhe_uint16_list(pks, C.uint16_t(value), out) case FheUint32: - C.client_key_encrypt_and_ser_fhe_uint32(cks, C.uint(value), out) + C.public_key_encrypt_and_serialize_fhe_uint32_list(pks, C.uint32_t(value), out) } - result := C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) - C.destroy_buffer(out) - return result -} -func publicKeyEncrypt(pks []byte, value uint64, t fheUintType) []byte { - out := &C.Buffer{} - switch t { - case FheUint8: - C.public_key_encrypt_fhe_uint8(toBufferView(pks), C.uchar(value), out) - case FheUint16: - C.public_key_encrypt_fhe_uint16(toBufferView(pks), C.ushort(value), out) - case FheUint32: - C.public_key_encrypt_fhe_uint32(toBufferView(pks), C.uint(value), out) - } - result := C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) + ser := C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) C.destroy_buffer(out) - return result + return ser } diff --git a/core/vm/tfhe_test.go b/core/vm/tfhe_test.go index 9386a24502c0..e7d05ac6d7e4 100644 --- a/core/vm/tfhe_test.go +++ b/core/vm/tfhe_test.go @@ -25,7 +25,7 @@ import ( // TODO: Don't rely on global keys that are loaded from disk in init(). Instead, // generate keys on demand in the test. -func TfheCksEncryptDecrypt(t *testing.T, fheUintType fheUintType) { +func TfheEncryptDecrypt(t *testing.T, fheUintType fheUintType) { var val big.Int switch fheUintType { case FheUint8: @@ -43,8 +43,50 @@ func TfheCksEncryptDecrypt(t *testing.T, fheUintType fheUintType) { } } +func TfheTrivialEncryptDecrypt(t *testing.T, fheUintType fheUintType) { + var val big.Int + switch fheUintType { + case FheUint8: + val.SetUint64(2) + case FheUint16: + val.SetUint64(1337) + case FheUint32: + val.SetUint64(1333337) + } + ct := new(tfheCiphertext) + ct.trivialEncrypt(val, fheUintType) + res := ct.decrypt() + if res.Uint64() != val.Uint64() { + t.Fatalf("%d != %d", val.Uint64(), res.Uint64()) + } +} + func TfheSerializeDeserialize(t *testing.T, fheUintType fheUintType) { - var val uint64 + var val big.Int + switch fheUintType { + case FheUint8: + val = *big.NewInt(2) + case FheUint16: + val = *big.NewInt(1337) + case FheUint32: + val = *big.NewInt(1333337) + } + ct1 := new(tfheCiphertext) + ct1.encrypt(val, fheUintType) + ct1Ser := ct1.serialize() + ct2 := new(tfheCiphertext) + err := ct2.deserialize(ct1Ser, fheUintType) + if err != nil { + t.Fatalf("deserialization failed") + } + ct2Ser := ct2.serialize() + if !bytes.Equal(ct1Ser, ct2Ser) { + t.Fatalf("serialization is non-deterministic") + } +} + +func TfheSerializeDeserializeCompact(t *testing.T, fheUintType fheUintType) { + var val uint32 switch fheUintType { case FheUint8: val = 2 @@ -53,15 +95,53 @@ func TfheSerializeDeserialize(t *testing.T, fheUintType fheUintType) { case FheUint32: val = 1333337 } - ctBytes := clientKeyEncrypt(val, fheUintType) - ct := new(tfheCiphertext) - err := ct.deserialize(ctBytes, fheUintType) + + ser := encryptAndSerializeCompact(val, fheUintType) + ct1 := new(tfheCiphertext) + err := ct1.deserializeCompact(ser, fheUintType) + if err != nil { + t.Fatalf("ct1 compact deserialization failed") + } + ct1Ser := ct1.serialize() + + ct2 := new(tfheCiphertext) + err = ct2.deserialize(ct1Ser, fheUintType) + if err != nil { + t.Fatalf("ct2 deserialization failed") + } + + ct2Ser := ct2.serialize() + if !bytes.Equal(ct1Ser, ct2Ser) { + t.Fatalf("serialization is non-deterministic") + } + + decrypted := ct2.decrypt() + if uint32(decrypted.Uint64()) != val { + t.Fatalf("decrypted value is incorrect") + } +} + +func TfheTrivialSerializeDeserialize(t *testing.T, fheUintType fheUintType) { + var val big.Int + switch fheUintType { + case FheUint8: + val = *big.NewInt(2) + case FheUint16: + val = *big.NewInt(1337) + case FheUint32: + val = *big.NewInt(1333337) + } + ct1 := new(tfheCiphertext) + ct1.trivialEncrypt(val, fheUintType) + ct1Ser := ct1.serialize() + ct2 := new(tfheCiphertext) + err := ct2.deserialize(ct1Ser, fheUintType) if err != nil { t.Fatalf("deserialization failed") } - serialized := ct.serialize() - if !bytes.Equal(serialized, ctBytes) { - t.Fatalf("serialization failed") + ct2Ser := ct2.serialize() + if !bytes.Equal(ct1Ser, ct2Ser) { + t.Fatalf("trivial serialization is non-deterministic") } } @@ -73,6 +153,36 @@ func TfheDeserializeFailure(t *testing.T, fheUintType fheUintType) { } } +func TfheDeserializeCompact(t *testing.T, fheUintType fheUintType) { + var val uint32 + switch fheUintType { + case FheUint8: + val = 2 + case FheUint16: + val = 1337 + case FheUint32: + val = 1333337 + } + ser := encryptAndSerializeCompact(val, fheUintType) + ct := new(tfheCiphertext) + err := ct.deserializeCompact(ser, fheUintType) + if err != nil { + t.Fatalf("compact deserialization failed") + } + decryptedVal := ct.decrypt() + if uint32(decryptedVal.Uint64()) != val { + t.Fatalf("compact deserialization wrong decryption") + } +} + +func TfheDeserializeCompactFailure(t *testing.T, fheUintType fheUintType) { + ct := new(tfheCiphertext) + err := ct.deserializeCompact(make([]byte, 10), fheUintType) + if err == nil { + t.Fatalf("compact deserialization must have failed") + } +} + func TfheAdd(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { @@ -176,6 +286,7 @@ func TfheLte(t *testing.T, fheUintType fheUintType) { t.Fatalf("%d != %d", 0, res2.Uint64()) } } + func TfheLt(t *testing.T, fheUintType fheUintType) { var a, b big.Int switch fheUintType { @@ -205,16 +316,28 @@ func TfheLt(t *testing.T, fheUintType fheUintType) { } } -func TestTfheCksEncryptDecrypt8(t *testing.T) { - TfheCksEncryptDecrypt(t, FheUint8) +func TestTfheEncryptDecrypt8(t *testing.T) { + TfheEncryptDecrypt(t, FheUint8) +} + +func TestTfheEncryptDecrypt16(t *testing.T) { + TfheEncryptDecrypt(t, FheUint16) +} + +func TestTfheEncryptDecrypt32(t *testing.T) { + TfheEncryptDecrypt(t, FheUint32) } -func TestTfheCksEncryptDecrypt16(t *testing.T) { - TfheCksEncryptDecrypt(t, FheUint16) +func TestTfheTrivialEncryptDecrypt8(t *testing.T) { + TfheTrivialEncryptDecrypt(t, FheUint8) } -func TestTfheCksEncryptDecrypt32(t *testing.T) { - TfheCksEncryptDecrypt(t, FheUint32) +func TestTfheTrivialEncryptDecrypt16(t *testing.T) { + TfheTrivialEncryptDecrypt(t, FheUint16) +} + +func TestTfheTrivialEncryptDecrypt32(t *testing.T) { + TfheTrivialEncryptDecrypt(t, FheUint32) } func TestTfheSerializeDeserialize8(t *testing.T) { @@ -229,6 +352,30 @@ func TestTfheSerializeDeserialize32(t *testing.T) { TfheSerializeDeserialize(t, FheUint32) } +func TestTfheSerializeDeserializeCompact8(t *testing.T) { + TfheSerializeDeserializeCompact(t, FheUint8) +} + +func TestTfheSerializeDeserializeCompact16(t *testing.T) { + TfheSerializeDeserializeCompact(t, FheUint16) +} + +func TestTfheSerializeDeserializeCompact32(t *testing.T) { + TfheSerializeDeserializeCompact(t, FheUint32) +} + +func TestTfheTrivialSerializeDeserialize8(t *testing.T) { + TfheTrivialSerializeDeserialize(t, FheUint8) +} + +func TestTfheTrivialSerializeDeserialize16(t *testing.T) { + TfheTrivialSerializeDeserialize(t, FheUint16) +} + +func TestTfheTrivialSerializeDeserialize32(t *testing.T) { + TfheTrivialSerializeDeserialize(t, FheUint32) +} + func TestTfheDeserializeFailure8(t *testing.T) { TfheDeserializeFailure(t, FheUint8) } @@ -241,6 +388,30 @@ func TestTfheDeserializeFailure32(t *testing.T) { TfheDeserializeFailure(t, FheUint32) } +func TestTfheDeserializeCompact8(t *testing.T) { + TfheDeserializeCompact(t, FheUint8) +} + +func TestTfheDeserializeCompact16(t *testing.T) { + TfheDeserializeCompact(t, FheUint16) +} + +func TestTfheDeserializeCompatc32(t *testing.T) { + TfheDeserializeCompact(t, FheUint32) +} + +func TestTfheDeserializeCompactFailure8(t *testing.T) { + TfheDeserializeCompactFailure(t, FheUint8) +} + +func TestTfheDeserializeCompactFailure16(t *testing.T) { + TfheDeserializeCompactFailure(t, FheUint16) +} + +func TestTfheDeserializeCompatcFailure32(t *testing.T) { + TfheDeserializeCompactFailure(t, FheUint32) +} + func TestTfheAdd8(t *testing.T) { TfheAdd(t, FheUint8) } @@ -299,54 +470,3 @@ func TestTfheLte32(t *testing.T) { func TestTfheLt32(t *testing.T) { TfheLte(t, FheUint32) } - -// func TestTfheTrivialEncryptDecrypt(t *testing.T) { -// val := uint64(2) -// ct := new(tfheCiphertext) -// ct.trivialEncrypt(val) -// res := ct.decrypt() -// if res != val { -// t.Fatalf("%d != %d", val, res) -// } -// } - -// func TestTfheTrivialAndEncryptedLte(t *testing.T) { -// a := uint64(2) -// b := uint64(1) -// ctA := new(tfheCiphertext) -// ctA.encrypt(a) -// ctB := new(tfheCiphertext) -// ctB.trivialEncrypt(b) -// ctRes1 := ctA.lte(ctB) -// ctRes2 := ctB.lte(ctA) -// res1 := ctRes1.decrypt() -// res2 := ctRes2.decrypt() -// if res1 != 0 { -// t.Fatalf("%d != %d", 0, res1) -// } -// if res2 != 1 { -// t.Fatalf("%d != %d", 0, res2) -// } -// } - -// func TestTfheTrivialAndEncryptedAdd(t *testing.T) { -// a := uint64(1) -// b := uint64(1) -// ctA := new(tfheCiphertext) -// ctA.encrypt(a) -// ctB := new(tfheCiphertext) -// ctB.trivialEncrypt(b) -// ctRes := ctA.add(ctB) -// res := ctRes.decrypt() -// if res != 2 { -// t.Fatalf("%d != %d", 0, res) -// } -// } - -// func TestTfheTrivialSerializeSize(t *testing.T) { -// ct := new(tfheCiphertext) -// ct.trivialEncrypt(2) -// if len(ct.serialize()) != fheCiphertextSize { -// t.Fatalf("serialization of trivially encrypted unexpected size") -// } -// } diff --git a/install_thfe_rs_api.sh b/install_thfe_rs_api.sh index 8af191b302ed..8f3b0768f087 100755 --- a/install_thfe_rs_api.sh +++ b/install_thfe_rs_api.sh @@ -1,6 +1,7 @@ #!/bin/bash git clone https://github.com/zama-ai/tfhe-rs.git +git checkout 1d817c45d5234bcf33638406191b656998b30c2a mkdir -p core/vm/lib cd tfhe-rs make build_c_api