Skip to content

Commit

Permalink
Add support for compact pub keys and ciphertexts (ethereum#108)
Browse files Browse the repository at this point in the history
The public key `pks` is now a `CompactPublicKey` in tfhe-rs terminology.

For inputs, we use compact lists from tfhe-rs. Essentially, transaction
inputs are serialized tfhe-rs compact lists. In the verifyCiphertext
precompile, we deserialize the list and then expand it to a normal
ciphertext. From that point on, nothing changes and we compute on and
persist expanded ciphertexts. Ciphertexts in protected storage are not
compact and are always persisted in their expanded form.

Use trivial encryption to determine expanded ciphertext size at startup.

Remove secret key encryption in tfhe_test.go - instead, only use public
key encryption. The `pks` key is expected to always be available to the
node. We will use that later for serving it to clients (potentially
over a special precompile for that purpose).

Add tests for the verifyCiphertext precompile.
Note: As of now, we cannot detect arbitrary ciphertexts during
deserialization time. Therefore, verifyCiphertext might or might not
fail and might produce random ciphertexts. Will work on that separately.

Generate keys using the zbc-fhe-tool.

Remove "random" tfheCiphertexts and, instead, always use public key
encryption during gas estimation. Rationale is that this ensures valid
ciphertexts in all code paths and it also ensures the actual ciphertext
is random-looking bytes.

Always persists ciphertexts in protected storage during opSstore, even
if the Commit flag is not set. Reason is that we want the same code
paths during gas estimation and actual transaction. If we skip
persisting, opSload will behave differently on gas estimation and
transactions.

Do not skip ciphertext verification on verifyCiphertext when the Commit
flag is not set. Rationale is, again, that we want the same code path
for gas estimation and transactions.
  • Loading branch information
dartdart26 authored Jun 15, 2023
1 parent db57ed7 commit 1bdc776
Show file tree
Hide file tree
Showing 8 changed files with 575 additions and 276 deletions.
27 changes: 12 additions & 15 deletions .github/workflows/publish_geth_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,43 +11,40 @@ jobs:
name: Build and test go-ethereum
runs-on: ubuntu-latest
steps:
- name: Check out tfhe-rs
- name: Checkout TFHE-rs
uses: actions/checkout@v3
with:
repository: zama-ai/tfhe-rs
ref: 0.2.1
ref: 1d817c45d5234bcf33638406191b656998b30c2a
path: tfhe-rs

- name: Check out tfhe-cli
- name: Checkout zbc-fhe-tool
uses: actions/checkout@v3
with:
repository: tremblaythibaultl/tfhe-cli
repository: zama-ai/zbc-fhe-tool
ref: main
path: tfhe-cli
token: ${{ secrets.CONCRETE_ACTIONS_TOKEN }}
path: zbc-fhe-tool

- name: Check out go-ethereum
- name: Checkout go-ethereum
uses: actions/checkout@v3
with:
path: go-ethereum

- name: Build C API
working-directory: ./tfhe-rs
run: make build_c_api
run: make build_c_api_experimental_deterministic_fft

- name: Move library files
run: |
mv ./tfhe-rs/target/release/tfhe.h ./go-ethereum/core/vm/
sudo mv ./tfhe-rs/target/release/libtfhe.* /usr/lib/
sudo mv ./tfhe-rs/target/release/tfhe.h /usr/include
sudo mv ./tfhe-rs/target/release/libtfhe.so /usr/lib/
- name: Generate TFHE-rs keys
working-directory: ./tfhe-cli
run: cargo run --release keygen bin .

- name: Move keys
working-directory: ./zbc-fhe-tool
run: |
mkdir -p $HOME/.evmosd/zama/keys/network-fhe-keys
mv ./tfhe-cli/client_key.bin $HOME/.evmosd/zama/keys/network-fhe-keys/cks
mv ./tfhe-cli/server_key.bin $HOME/.evmosd/zama/keys/network-fhe-keys/sks
cargo run --features tfhe/x86_64-unix --release -- generate-keys -d $HOME/.evmosd/zama/keys/network-fhe-keys
- name: Run tests
working-directory: ./go-ethereum/core/vm
Expand Down
17 changes: 7 additions & 10 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ func importCiphertext(accessibleState PrecompileAccessibleState, ct *tfheCiphert
// Used when we want to skip FHE computation, e.g. gas estimation.
func importRandomCiphertext(accessibleState PrecompileAccessibleState, t fheUintType) []byte {
ct := new(tfheCiphertext)
ct.makeRandom(t)
ct.encrypt(*big.NewInt(0), t)
importCiphertext(accessibleState, ct)
ctHash := ct.getHash()
return ctHash[:]
Expand Down Expand Up @@ -1435,13 +1435,8 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller
ctBytes := input[:len(input)-1]
ctType := fheUintType(input[len(input)-1])

// If we are doing gas estimation, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall {
return importRandomCiphertext(accessibleState, ctType), nil
}

ct := new(tfheCiphertext)
err := ct.deserialize(ctBytes, ctType)
err := ct.deserializeCompact(ctBytes, ctType)
if err != nil {
logger.Error("verifyCiphertext failed to deserialize input ciphertext",
"err", err,
Expand All @@ -1451,9 +1446,11 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller
}
ctHash := ct.getHash()
importCiphertext(accessibleState, ct)
logger.Info("verifyCiphertext success",
"ctHash", ctHash.Hex(),
"ctBytes64", hex.EncodeToString(ctBytes[:minInt(len(ctBytes), 64)]))
if accessibleState.Interpreter().evm.Commit {
logger.Info("verifyCiphertext success",
"ctHash", ctHash.Hex(),
"ctBytes64", hex.EncodeToString(ctBytes[:minInt(len(ctBytes), 64)]))
}
return ctHash.Bytes(), nil
}

Expand Down
134 changes: 111 additions & 23 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,13 @@ func newTestState() *statefulPrecompileAccessibleState {
}

func verifyCiphertextInTestMemory(interpreter *EVMInterpreter, value uint64, depth int, t fheUintType) *tfheCiphertext {
// Simulate as if the ciphertext is compact and comes externally.
ser := encryptAndSerializeCompact(uint32(value), t)
ct := new(tfheCiphertext)
ct.encrypt(*new(big.Int).SetUint64(value), t)
err := ct.deserializeCompact(ser, t)
if err != nil {
panic(err)
}
return verifyTfheCiphertextInTestMemory(interpreter, ct, depth)
}

Expand All @@ -455,6 +460,68 @@ func toPrecompileInput(hashes ...common.Hash) []byte {
return ret
}

func VerifyCiphertext(t *testing.T, fheUintType fheUintType) {
var value uint32
switch fheUintType {
case FheUint8:
value = 2
case FheUint16:
value = 4283
case FheUint32:
value = 1333337
}
c := &verifyCiphertext{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
compact := encryptAndSerializeCompact(value, fheUintType)
input := append(compact, byte(fheUintType))
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
ct := new(tfheCiphertext)
if err = ct.deserializeCompact(compact, fheUintType); err != nil {
t.Fatalf(err.Error())
}
if common.BytesToHash(out) != ct.getHash() {
t.Fatalf("output hash in verifyCipertext is incorrect")
}
res := getVerifiedCiphertextFromEVM(state.interpreter, ct.getHash())
if res == nil {
t.Fatalf("verifyCiphertext must have verified given ciphertext")
}
}

func VerifyCiphertextBadType(t *testing.T, actualType fheUintType, metadataType fheUintType) {
var value uint32
switch actualType {
case FheUint8:
value = 2
case FheUint16:
value = 4283
case FheUint32:
value = 1333337
}
c := &verifyCiphertext{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
compact := encryptAndSerializeCompact(value, actualType)
input := append(compact, byte(metadataType))
_, err := c.Run(state, addr, addr, input, readOnly)
if err == nil {
t.Fatalf("verifyCiphertext must have failed on type mismatch")
}
if len(state.interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("verifyCiphertext mustn't have verified given ciphertext")
}
}

func FheAdd(t *testing.T, fheUintType fheUintType) {
var lhs, rhs uint64
switch fheUintType {
Expand Down Expand Up @@ -673,6 +740,49 @@ func FheLt(t *testing.T, fheUintType fheUintType) {
}
}

func TestVerifyCiphertext8(t *testing.T) {
VerifyCiphertext(t, FheUint8)
}

func TestVerifyCiphertext16(t *testing.T) {
VerifyCiphertext(t, FheUint16)
}

func TestVerifyCiphertext32(t *testing.T) {
VerifyCiphertext(t, FheUint32)
}

// func TestVerifyCiphertext8BadType(t *testing.T) {
// VerifyCiphertextBadType(t, FheUint8, FheUint16)
// VerifyCiphertextBadType(t, FheUint8, FheUint32)
// }

// func TestVerifyCiphertext16BadType(t *testing.T) {
// VerifyCiphertextBadType(t, FheUint16, FheUint8)
// VerifyCiphertextBadType(t, FheUint16, FheUint32)
// }

// func TestVerifyCiphertext32BadType(t *testing.T) {
// VerifyCiphertextBadType(t, FheUint32, FheUint8)
// VerifyCiphertextBadType(t, FheUint32, FheUint16)
// }

func TestVerifyCiphertextBadCiphertext(t *testing.T) {
c := &verifyCiphertext{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
addr := common.Address{}
readOnly := false
_, err := c.Run(state, addr, addr, make([]byte, 10), readOnly)
if err == nil {
t.Fatalf("verifyCiphertext must fail on bad ciphertext input")
}
if len(state.interpreter.verifiedCiphertexts) != 0 {
t.Fatalf("verifyCiphertext mustn't have verified given ciphertext")
}
}

func TestFheAdd8(t *testing.T) {
FheAdd(t, FheUint8)
}
Expand Down Expand Up @@ -733,28 +843,6 @@ func TestFheLt32(t *testing.T) {
FheLt(t, FheUint32)
}

// func TestFheRand(t *testing.T) {
// c := &fheRand{}
// depth := 1
// state := newTestState()
// state.interpreter.evm.depth = depth
// addr := common.Address{}
// readOnly := false

// out, err := c.Run(state, addr, addr, nil, readOnly)
// if err != nil {
// t.Fatalf(err.Error())
// }
// res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out))
// if res == nil {
// t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
// }
// decrypted := res.ciphertext.decrypt()
// if decrypted >= math.Pow(2, 3) {
// t.Fatalf("invalid decrypted result")
// }
// }

func TestUnknownCiphertextHandle(t *testing.T) {
depth := 1
state := newTestState()
Expand Down
59 changes: 32 additions & 27 deletions core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,16 +620,18 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres
if metadataInt.IsZero() {
// If no metadata, it means this ciphertext itself hasn't been persisted to protected storage yet. We do that as part of SSTORE.
metadata.refCount = 1
metadata.length = uint64(fheCiphertextSize[verifiedCiphertext.ciphertext.fheUintType])
metadata.length = uint64(expandedFheCiphertextSize[verifiedCiphertext.ciphertext.fheUintType])
metadata.fheUintType = verifiedCiphertext.ciphertext.fheUintType
ciphertextSlot := newInt(val.Bytes())
ciphertextSlot.AddUint64(ciphertextSlot, 1)
logger.Info("opSstore persisting new ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"handle", hex.EncodeToString(val.Bytes()),
"type", metadata.fheUintType,
"len", metadata.length,
"ciphertextSlot", hex.EncodeToString(ciphertextSlot.Bytes()))
if interpreter.evm.Commit {
logger.Info("opSstore persisting new ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"handle", hex.EncodeToString(val.Bytes()),
"type", metadata.fheUintType,
"len", metadata.length,
"ciphertextSlot", hex.EncodeToString(ciphertextSlot.Bytes()))
}
ctPart32 := make([]byte, 32)
partIdx := 0
ctBytes := verifiedCiphertext.ciphertext.serialize()
Expand All @@ -650,12 +652,14 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres
// If metadata exists, bump the refcount by 1.
metadata = *newCiphertextMetadata(interpreter.evm.StateDB.GetState(protectedStorage, val))
metadata.refCount++
logger.Info("opSstore bumping refcount of existing ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"handle", hex.EncodeToString(val.Bytes()),
"type", metadata.fheUintType,
"len", metadata.length,
"refCount", metadata.refCount)
if interpreter.evm.Commit {
logger.Info("opSstore bumping refcount of existing ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"handle", hex.EncodeToString(val.Bytes()),
"type", metadata.fheUintType,
"len", metadata.length,
"refCount", metadata.refCount)
}
}
// Save the metadata in protected storage.
interpreter.evm.StateDB.SetState(protectedStorage, val, metadata.serialize())
Expand All @@ -669,11 +673,13 @@ func garbageCollectProtectedStorage(metadataKey common.Hash, protectedStorage co
logger := interpreter.evm.Logger
metadata := newCiphertextMetadata(existingMetadataInt.Bytes32())
if metadata.refCount == 1 {
logger.Info("opSstore garbage-collecting ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"metadataKey", hex.EncodeToString(metadataKey[:]),
"type", metadata.fheUintType,
"len", metadata.length)
if interpreter.evm.Commit {
logger.Info("opSstore garbage-collecting ciphertext",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"metadataKey", hex.EncodeToString(metadataKey[:]),
"type", metadata.fheUintType,
"len", metadata.length)
}

// Zero the metadata key-value.
interpreter.evm.StateDB.SetState(protectedStorage, metadataKey, zero)
Expand All @@ -692,11 +698,13 @@ func garbageCollectProtectedStorage(metadataKey common.Hash, protectedStorage co
slot.AddUint64(slot, 1)
}
} else if metadata.refCount > 1 {
logger.Info("opSstore decrementing ciphertext refCount",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"metadataKey", hex.EncodeToString(metadataKey[:]),
"type", metadata.fheUintType,
"len", metadata.length)
if interpreter.evm.Commit {
logger.Info("opSstore decrementing ciphertext refCount",
"protectedStorage", hex.EncodeToString(protectedStorage[:]),
"metadataKey", hex.EncodeToString(metadataKey[:]),
"type", metadata.fheUintType,
"len", metadata.length)
}
metadata.refCount--
interpreter.evm.StateDB.SetState(protectedStorage, existingMetadataHash, metadata.serialize())
}
Expand All @@ -713,10 +721,7 @@ func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]b
newValHash := common.BytesToHash(newValBytes)
oldValHash := interpreter.evm.StateDB.GetState(scope.Contract.Address(), common.Hash(loc.Bytes32()))
protectedStorage := crypto.CreateProtectedStorageContractAddress(scope.Contract.Address())
// Here, we assume that if the `Commit` flag is not set, no precompile would read/write an actual ciphertext
// from/to protected storage. Instead, all precompiles just insert random ciphertexts to memory.
// Therefore, if `Commit` is not set, we don't need to touch protected storage at all.
if interpreter.evm.Commit && newValHash != oldValHash {
if newValHash != oldValHash {
// Since the old value is no longer stored in actual contract storage, run garbage collection on protected storage.
garbageCollectProtectedStorage(oldValHash, protectedStorage, interpreter)
// If a verified ciphertext, persist to protected storage.
Expand Down
8 changes: 4 additions & 4 deletions core/vm/instructions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ func (c testCallerAddress) Address() common.Address {
func newTestScopeConext() *ScopeContext {
c := new(ScopeContext)
c.Memory = NewMemory()
c.Memory.Resize(uint64(fheCiphertextSize[FheUint8]) * 3)
c.Memory.Resize(uint64(expandedFheCiphertextSize[FheUint8]) * 3)
c.Stack = newstack()
c.Contract = NewContract(testCallerAddress{}, testContractAddress{}, big.NewInt(10), 100000)
return c
Expand All @@ -732,7 +732,7 @@ func TestProtectedStorageSstoreSload(t *testing.T) {
depth := 1
interpreter := newTestInterpreter()
interpreter.evm.depth = depth
ct := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8)
ct := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint32)
ctHash := ct.getHash()
scope := newTestScopeConext()
loc := uint256.NewInt(10)
Expand Down Expand Up @@ -791,8 +791,8 @@ func TestProtectedStorageGarbageCollection(t *testing.T) {
if metadata.refCount != 1 {
t.Fatalf("metadata.refcount of ciphertext is not 1")
}
if metadata.length != uint64(fheCiphertextSize[FheUint8]) {
t.Fatalf("metadata.length (%v) != ciphertext len (%v)", metadata.length, uint64(fheCiphertextSize[FheUint8]))
if metadata.length != uint64(expandedFheCiphertextSize[FheUint8]) {
t.Fatalf("metadata.length (%v) != ciphertext len (%v)", metadata.length, uint64(expandedFheCiphertextSize[FheUint8]))
}
ciphertextLocationsToCheck := (metadata.length + 32 - 1) / 32
startOfCiphertext := newInt(ctHash[:])
Expand Down
Loading

0 comments on commit 1bdc776

Please sign in to comment.