diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 56c1d626634e..ff7c05be25c9 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -77,7 +77,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -106,7 +106,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -136,7 +136,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -166,7 +166,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -196,7 +196,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{73}): &fheLt{}, // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, - // common.BytesToAddress([]byte{76}): &cast{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{77}): &trivialEncrypt{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -1438,7 +1438,7 @@ func (e *verifyCiphertext) RequiredGas(accessibleState PrecompileAccessibleState func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { logger := accessibleState.Interpreter().evm.Logger if len(input) <= 1 { - msg := "verifyCiphertext RequiredGas() input needs to contain a ciphertext and one byte for its type" + msg := "verifyCiphertext Run() input needs to contain a ciphertext and one byte for its type" logger.Error(msg, "len", len(input)) return nil, errors.New(msg) } @@ -2026,18 +2026,57 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add // return ctHash[:], nil // } -// type cast struct{} +type cast struct{} -// func (e *cast) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { -// return 0 -// } +func (e *cast) RequiredGas(accessibleState PrecompileAccessibleState, input []byte) uint64 { + if len(input) != 33 { + accessibleState.Interpreter().evm.Logger.Error( + "cast RequiredGas() input needs to contain a ciphertext and one byte for its type", + "len", len(input)) + return 0 + } + return params.FheCastGas +} -// // Implementation of the following is pending and will be completed once TFHE-rs add type casts to their high-level C API. -// func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { -// // var ctHandle = common.BytesToHash(input[0:31]) -// // var toType = input[32] -// return nil, nil -// } +// Implementation of the following is pending and will be completed once TFHE-rs add type casts to their high-level C API. +func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + logger := accessibleState.Interpreter().evm.Logger + if len(input) != 33 { + msg := "cast Run() input needs to contain a ciphertext and one byte for its type" + logger.Error(msg, "len", len(input)) + return nil, errors.New(msg) + } + + ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32])) + if ct == nil { + logger.Error("cast input not verified") + return nil, errors.New("unverified ciphertext handle") + } + + castToType := fheUintType(input[32]) + if !castToType.isValid() { + logger.Error("invalid type to cast to") + return nil, errors.New("invalid type provided") + } + + res, err := ct.ciphertext.castTo(castToType) + if err != nil { + msg := "cast Run() error casting ciphertext to" + logger.Error(msg, "type", castToType) + return nil, errors.New(msg) + } + + resHash := res.getHash() + + importCiphertext(accessibleState, res) + if accessibleState.Interpreter().evm.Commit { + logger.Info("cast success", + "ctHash", resHash.Hex(), + ) + } + + return resHash.Bytes(), nil +} type faucet struct{} diff --git a/core/vm/tfhe.go b/core/vm/tfhe.go index df62425358e5..00ba1cde5178 100644 --- a/core/vm/tfhe.go +++ b/core/vm/tfhe.go @@ -480,6 +480,66 @@ void public_key_encrypt_and_serialize_fhe_uint32_list(void* pks, uint32_t value, assert(r == 0); } +void* cast_8_16(void* ct, void* sks) { + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_cast_into_fhe_uint16(ct, &result); + assert(r == 0); + return result; +} + +void* cast_8_32(void* ct, void* sks) { + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint8_cast_into_fhe_uint32(ct, &result); + assert(r == 0); + return result; +} + +void* cast_16_8(void* ct, void* sks) { + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_cast_into_fhe_uint8(ct, &result); + assert(r == 0); + return result; +} + +void* cast_16_32(void* ct, void* sks) { + FheUint32* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint16_cast_into_fhe_uint32(ct, &result); + assert(r == 0); + return result; +} + +void* cast_32_8(void* ct, void* sks) { + FheUint8* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_cast_into_fhe_uint8(ct, &result); + assert(r == 0); + return result; +} + +void* cast_32_16(void* ct, void* sks) { + FheUint16* result = NULL; + + checked_set_server_key(sks); + + const int r = fhe_uint32_cast_into_fhe_uint16(ct, &result); + assert(r == 0); + return result; +} + */ import "C" @@ -810,6 +870,49 @@ func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) (*tfheCiphertext, error) { return res, nil } +func (ct *tfheCiphertext) castTo(castToType fheUintType) (*tfheCiphertext, error) { + if !ct.availableForOps() { + panic("cannot cast a non-initialized ciphertext") + } + + if ct.fheUintType == castToType { + return nil, errors.New("casting to same type is not supported") + } + + if !castToType.isValid() { + return nil, errors.New("invalid type to cast to") + } + + res := new(tfheCiphertext) + res.fheUintType = castToType + + switch ct.fheUintType { + case FheUint8: + switch castToType { + case FheUint16: + res.setPtr(C.cast_8_16(ct.ptr, sks)) + case FheUint32: + res.setPtr(C.cast_8_32(ct.ptr, sks)) + } + case FheUint16: + switch castToType { + case FheUint8: + res.setPtr(C.cast_16_8(ct.ptr, sks)) + case FheUint32: + res.setPtr(C.cast_16_32(ct.ptr, sks)) + } + case FheUint32: + switch castToType { + case FheUint8: + res.setPtr(C.cast_32_8(ct.ptr, sks)) + case FheUint16: + res.setPtr(C.cast_32_16(ct.ptr, sks)) + } + } + + return res, nil +} + func (ct *tfheCiphertext) decrypt() big.Int { if !ct.availableForOps() { panic("cannot decrypt a null ciphertext") @@ -869,6 +972,10 @@ func (ct *tfheCiphertext) initialized() bool { return (ct.ptr != nil) } +func (t *fheUintType) isValid() bool { + return (*t <= 2) +} + // Used for testing. func encryptAndSerializeCompact(value uint32, fheUintType fheUintType) []byte { out := &C.Buffer{} diff --git a/core/vm/tfhe_test.go b/core/vm/tfhe_test.go index e7d05ac6d7e4..58c02d709ac0 100644 --- a/core/vm/tfhe_test.go +++ b/core/vm/tfhe_test.go @@ -18,6 +18,7 @@ package vm import ( "bytes" + "math" "math/big" "testing" ) @@ -301,9 +302,9 @@ func TfheLt(t *testing.T, fheUintType fheUintType) { b.SetUint64(133337) } ctA := new(tfheCiphertext) - ctA.encrypt(a, FheUint8) + ctA.encrypt(a, fheUintType) ctB := new(tfheCiphertext) - ctB.encrypt(b, FheUint8) + ctB.encrypt(b, fheUintType) ctRes1, _ := ctA.lte(ctB) ctRes2, _ := ctB.lte(ctA) res1 := ctRes1.decrypt() @@ -316,6 +317,44 @@ func TfheLt(t *testing.T, fheUintType fheUintType) { } } +func TfheCast(t *testing.T, fheUintTypeFrom fheUintType, fheUintTypeTo fheUintType) { + var a big.Int + switch fheUintTypeFrom { + case FheUint8: + a.SetUint64(2) + case FheUint16: + a.SetUint64(4283) + case FheUint32: + a.SetUint64(1333337) + } + + var modulus uint64 + switch fheUintTypeTo { + case FheUint8: + modulus = uint64(math.Pow(2, 8)) + case FheUint16: + modulus = uint64(math.Pow(2, 16)) + case FheUint32: + modulus = uint64(math.Pow(2, 32)) + } + + ctA := new(tfheCiphertext) + ctA.encrypt(a, fheUintTypeFrom) + ctRes, err := ctA.castTo(fheUintTypeTo) + if err != nil { + t.Fatal(err) + } + + if ctRes.fheUintType != fheUintTypeTo { + t.Fatalf("type %d != type %d", ctA.fheUintType, fheUintTypeTo) + } + res := ctRes.decrypt() + expected := a.Uint64() % modulus + if res.Uint64() != expected { + t.Fatalf("%d != %d", res.Uint64(), expected) + } +} + func TestTfheEncryptDecrypt8(t *testing.T) { TfheEncryptDecrypt(t, FheUint8) } @@ -470,3 +509,27 @@ func TestTfheLte32(t *testing.T) { func TestTfheLt32(t *testing.T) { TfheLte(t, FheUint32) } + +func TestTfhe8Cast16(t *testing.T) { + TfheCast(t, FheUint8, FheUint16) +} + +func TestTfhe8Cast32(t *testing.T) { + TfheCast(t, FheUint8, FheUint32) +} + +func TestTfhe16Cast8(t *testing.T) { + TfheCast(t, FheUint16, FheUint8) +} + +func TestTfhe16Cast32(t *testing.T) { + TfheCast(t, FheUint16, FheUint32) +} + +func TestTfhe32Cast8(t *testing.T) { + TfheCast(t, FheUint16, FheUint8) +} + +func TestTfhe32Cast16(t *testing.T) { + TfheCast(t, FheUint16, FheUint8) +} diff --git a/params/protocol_params.go b/params/protocol_params.go index 4be3bb685483..73bc68de57f5 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -206,6 +206,8 @@ const ( FheUint16ProtectedStorageSloadGas uint64 = FheUint8ProtectedStorageSloadGas * 2 FheUint32ProtectedStorageSloadGas uint64 = FheUint16ProtectedStorageSloadGas * 4 + FheCastGas uint64 = 100 + FhePubKeyGas uint64 = 2 FheUint8TrivialEncryptGas uint64 = 100