Skip to content

Commit

Permalink
Add dedicated struct to hold the special type IDs from the EVM contract
Browse files Browse the repository at this point in the history
  • Loading branch information
m-Peter committed Oct 10, 2024
1 parent e8694e9 commit d004ec5
Showing 1 changed file with 61 additions and 44 deletions.
105 changes: 61 additions & 44 deletions fvm/evm/impl/abi.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,32 @@ func (e abiDecodingError) Error() string {
return b.String()
}

type evmSpecialTypeIDs struct {
AddressTypeID common.TypeID
BytesTypeID common.TypeID
Bytes4TypeID common.TypeID
Bytes32TypeID common.TypeID
}

func NewEVMSpecialTypeIDs(
gauge common.MemoryGauge,
location common.AddressLocation,
) evmSpecialTypeIDs {
return evmSpecialTypeIDs{
AddressTypeID: location.TypeID(gauge, stdlib.EVMAddressTypeQualifiedIdentifier),
BytesTypeID: location.TypeID(gauge, stdlib.EVMBytesTypeQualifiedIdentifier),
Bytes4TypeID: location.TypeID(gauge, stdlib.EVMBytes4TypeQualifiedIdentifier),
Bytes32TypeID: location.TypeID(gauge, stdlib.EVMBytes32TypeQualifiedIdentifier),
}
}

func reportABIEncodingComputation(
inter *interpreter.Interpreter,
locationRange interpreter.LocationRange,
values *interpreter.ArrayValue,
evmLocation common.AddressLocation,
evmTypeIDs evmSpecialTypeIDs,
reportComputation func(intensity uint),
) {
evmAddressTypeID := evmLocation.TypeID(inter, stdlib.EVMAddressTypeQualifiedIdentifier)
evmBytesTypeID := evmLocation.TypeID(inter, stdlib.EVMBytesTypeQualifiedIdentifier)
evmBytes4TypeID := evmLocation.TypeID(inter, stdlib.EVMBytes4TypeQualifiedIdentifier)
evmBytes32TypeID := evmLocation.TypeID(inter, stdlib.EVMBytes32TypeQualifiedIdentifier)

values.Iterate(
inter,
Expand Down Expand Up @@ -129,14 +144,14 @@ func reportABIEncodingComputation(

case *interpreter.CompositeValue:
switch value.TypeID() {
case evmAddressTypeID:
case evmTypeIDs.AddressTypeID:
// EVM addresses are static variables with a fixed
// size of 32 bytes.
reportComputation(abiEncodingByteSize)

case evmBytesTypeID,
evmBytes4TypeID,
evmBytes32TypeID:
case evmTypeIDs.BytesTypeID,
evmTypeIDs.Bytes4TypeID,
evmTypeIDs.Bytes32TypeID:

computation := uint(2 * abiEncodingByteSize)
bytesArrayValue := value.GetMember(inter, locationRange, stdlib.EVMBytesTypeValueFieldName)
Expand Down Expand Up @@ -165,7 +180,7 @@ func reportABIEncodingComputation(
inter,
locationRange,
value,
evmLocation,
evmTypeIDs,
reportComputation,
)

Expand All @@ -188,7 +203,7 @@ func newInternalEVMTypeEncodeABIFunction(
location common.AddressLocation,
) *interpreter.HostFunctionValue {

evmAddressTypeID := location.TypeID(gauge, stdlib.EVMAddressTypeQualifiedIdentifier)
evmSpecialTypeIDs := NewEVMSpecialTypeIDs(gauge, location)

return interpreter.NewStaticHostFunctionValue(
gauge,
Expand All @@ -208,7 +223,7 @@ func newInternalEVMTypeEncodeABIFunction(
inter,
locationRange,
valuesArray,
location,
evmSpecialTypeIDs,
func(intensity uint) {
inter.ReportComputation(environment.ComputationKindEVMEncodeABI, intensity)
},
Expand All @@ -227,7 +242,7 @@ func newInternalEVMTypeEncodeABIFunction(
locationRange,
element,
element.StaticType(inter),
evmAddressTypeID,
evmSpecialTypeIDs,
)
if err != nil {
panic(err)
Expand Down Expand Up @@ -295,7 +310,7 @@ var gethTypeBytes32 = gethABI.Type{T: gethABI.BytesTy, Size: 32}

func gethABIType(
staticType interpreter.StaticType,
evmAddressTypeID common.TypeID,
evmTypeIDs evmSpecialTypeIDs,
) (gethABI.Type, bool) {
switch staticType {
case interpreter.PrimitiveStaticTypeString:
Expand Down Expand Up @@ -336,26 +351,26 @@ func gethABIType(

switch staticType := staticType.(type) {
case *interpreter.CompositeStaticType:
if staticType.TypeID == evmAddressTypeID {
if staticType.TypeID == evmTypeIDs.AddressTypeID {
return gethTypeAddress, true
}

if staticType.TypeID == "A.0000000000000001.EVM.EVMBytes" {
if staticType.TypeID == evmTypeIDs.BytesTypeID {
return gethTypeBytes, true
}

if staticType.TypeID == "A.0000000000000001.EVM.EVMBytes4" {
if staticType.TypeID == evmTypeIDs.Bytes4TypeID {
return gethTypeBytes4, true
}

if staticType.TypeID == "A.0000000000000001.EVM.EVMBytes32" {
if staticType.TypeID == evmTypeIDs.Bytes32TypeID {
return gethTypeBytes32, true
}

case *interpreter.ConstantSizedStaticType:
elementGethABIType, ok := gethABIType(
staticType.ElementType(),
evmAddressTypeID,
evmTypeIDs,
)
if !ok {
break
Expand All @@ -370,7 +385,7 @@ func gethABIType(
case *interpreter.VariableSizedStaticType:
elementGethABIType, ok := gethABIType(
staticType.ElementType(),
evmAddressTypeID,
evmTypeIDs,
)
if !ok {
break
Expand All @@ -388,7 +403,7 @@ func gethABIType(

func goType(
staticType interpreter.StaticType,
evmAddressTypeID common.TypeID,
evmTypeIDs evmSpecialTypeIDs,
) (reflect.Type, bool) {
switch staticType {
case interpreter.PrimitiveStaticTypeString:
Expand Down Expand Up @@ -429,35 +444,35 @@ func goType(

switch staticType := staticType.(type) {
case *interpreter.ConstantSizedStaticType:
elementType, ok := goType(staticType.ElementType(), evmAddressTypeID)
elementType, ok := goType(staticType.ElementType(), evmTypeIDs)
if !ok {
break
}

return reflect.ArrayOf(int(staticType.Size), elementType), true

case *interpreter.VariableSizedStaticType:
elementType, ok := goType(staticType.ElementType(), evmAddressTypeID)
elementType, ok := goType(staticType.ElementType(), evmTypeIDs)
if !ok {
break
}

return reflect.SliceOf(elementType), true
}

if staticType.ID() == evmAddressTypeID {
if staticType.ID() == evmTypeIDs.AddressTypeID {
return reflect.TypeOf(gethCommon.Address{}), true
}

if staticType.ID() == "A.0000000000000001.EVM.EVMBytes" {
if staticType.ID() == evmTypeIDs.BytesTypeID {
return reflect.SliceOf(reflect.TypeOf(byte(0))), true
}

if staticType.ID() == "A.0000000000000001.EVM.EVMBytes4" {
if staticType.ID() == evmTypeIDs.Bytes4TypeID {
return reflect.SliceOf(reflect.TypeOf(byte(0))), true
}

if staticType.ID() == "A.0000000000000001.EVM.EVMBytes32" {
if staticType.ID() == evmTypeIDs.Bytes32TypeID {
return reflect.SliceOf(reflect.TypeOf(byte(0))), true
}

Expand All @@ -469,7 +484,7 @@ func encodeABI(
locationRange interpreter.LocationRange,
value interpreter.Value,
staticType interpreter.StaticType,
evmAddressTypeID common.TypeID,
evmTypeIDs evmSpecialTypeIDs,
) (
any,
gethABI.Type,
Expand Down Expand Up @@ -570,7 +585,9 @@ func encodeABI(
}

case *interpreter.CompositeValue:
if value.TypeID() == evmAddressTypeID {
typeID := value.TypeID()

if typeID == evmTypeIDs.AddressTypeID {
addressBytesArrayValue := value.GetMember(inter, locationRange, stdlib.EVMAddressTypeBytesFieldName)
bytes, err := interpreter.ByteArrayValueToByteSlice(
inter,
Expand All @@ -584,7 +601,7 @@ func encodeABI(
return gethCommon.Address(bytes), gethTypeAddress, nil
}

if value.TypeID() == "A.0000000000000001.EVM.EVMBytes" {
if typeID == evmTypeIDs.BytesTypeID {
bytesArrayValue := value.GetMember(inter, locationRange, "value")
bytes, err := interpreter.ByteArrayValueToByteSlice(
inter,
Expand All @@ -598,7 +615,7 @@ func encodeABI(
return bytes, gethTypeBytes, nil
}

if value.TypeID() == "A.0000000000000001.EVM.EVMBytes4" {
if typeID == evmTypeIDs.Bytes4TypeID {
bytesArrayValue := value.GetMember(inter, locationRange, "value")
bytes, err := interpreter.ByteArrayValueToByteSlice(
inter,
Expand All @@ -612,7 +629,7 @@ func encodeABI(
return bytes, gethTypeBytes4, nil
}

if value.TypeID() == "A.0000000000000001.EVM.EVMBytes32" {
if typeID == evmTypeIDs.Bytes32TypeID {
bytesArrayValue := value.GetMember(inter, locationRange, "value")
bytes, err := interpreter.ByteArrayValueToByteSlice(
inter,
Expand All @@ -629,14 +646,14 @@ func encodeABI(
case *interpreter.ArrayValue:
arrayStaticType := value.Type

arrayGethABIType, ok := gethABIType(arrayStaticType, evmAddressTypeID)
arrayGethABIType, ok := gethABIType(arrayStaticType, evmTypeIDs)
if !ok {
break
}

elementStaticType := arrayStaticType.ElementType()

elementGoType, ok := goType(elementStaticType, evmAddressTypeID)
elementGoType, ok := goType(elementStaticType, evmTypeIDs)
if !ok {
break
}
Expand All @@ -663,7 +680,7 @@ func encodeABI(
locationRange,
element,
element.StaticType(inter),
evmAddressTypeID,
evmTypeIDs,
)
if err != nil {
panic(err)
Expand Down Expand Up @@ -691,7 +708,7 @@ func newInternalEVMTypeDecodeABIFunction(
gauge common.MemoryGauge,
location common.AddressLocation,
) *interpreter.HostFunctionValue {
evmAddressTypeID := location.TypeID(gauge, stdlib.EVMAddressTypeQualifiedIdentifier)
evmSpecialTypeIDs := NewEVMSpecialTypeIDs(gauge, location)

return interpreter.NewStaticHostFunctionValue(
gauge,
Expand Down Expand Up @@ -735,7 +752,7 @@ func newInternalEVMTypeDecodeABIFunction(

staticType := typeValue.Type

gethABITy, ok := gethABIType(staticType, evmAddressTypeID)
gethABITy, ok := gethABIType(staticType, evmSpecialTypeIDs)
if !ok {
panic(abiDecodingError{
Type: staticType,
Expand Down Expand Up @@ -780,7 +797,7 @@ func newInternalEVMTypeDecodeABIFunction(
decodedValues[index],
staticType,
location,
evmAddressTypeID,
evmSpecialTypeIDs,
)
if err != nil {
panic(err)
Expand Down Expand Up @@ -822,7 +839,7 @@ func decodeABI(
value any,
staticType interpreter.StaticType,
location common.AddressLocation,
evmAddressTypeID common.TypeID,
evmTypeIDs evmSpecialTypeIDs,
) (
interpreter.Value,
error,
Expand Down Expand Up @@ -981,7 +998,7 @@ func decodeABI(
element,
elementStaticType,
location,
evmAddressTypeID,
evmTypeIDs,
)
if err != nil {
panic(err)
Expand All @@ -994,7 +1011,7 @@ func decodeABI(
), nil

case *interpreter.CompositeStaticType:
if staticType.TypeID == evmAddressTypeID {
if staticType.TypeID == evmTypeIDs.AddressTypeID {
addr, ok := value.(gethCommon.Address)
if !ok {
break
Expand All @@ -1010,7 +1027,7 @@ func decodeABI(
), nil
}

if staticType.TypeID == "A.0000000000000001.EVM.EVMBytes" {
if staticType.TypeID == evmTypeIDs.BytesTypeID {
bytes, ok := value.([]uint8)
if !ok {
break
Expand All @@ -1023,7 +1040,7 @@ func decodeABI(
), nil
}

if staticType.TypeID == "A.0000000000000001.EVM.EVMBytes4" {
if staticType.TypeID == evmTypeIDs.Bytes4TypeID {
bytes, ok := value.([]uint8)
if !ok {
break
Expand All @@ -1036,7 +1053,7 @@ func decodeABI(
), nil
}

if staticType.TypeID == "A.0000000000000001.EVM.EVMBytes32" {
if staticType.TypeID == evmTypeIDs.Bytes32TypeID {
bytes, ok := value.([]uint8)
if !ok {
break
Expand Down

0 comments on commit d004ec5

Please sign in to comment.