From df1233c3a2d2063e52f255dc41ac3f1d6a47d2d5 Mon Sep 17 00:00:00 2001 From: Kelvin Fichter Date: Fri, 4 Oct 2024 19:52:10 -0400 Subject: [PATCH] feat(ct): port interface checks to go Ports the interface checks command to go. Updates build command to execute forge fmt and the interface checks. --- packages/contracts-bedrock/justfile | 15 +- .../scripts/checks/interfaces/main.go | 354 ++++++++++++++++++ .../scripts/checks/interfaces/main_test.go | 295 +++++++++++++++ 3 files changed, 661 insertions(+), 3 deletions(-) create mode 100644 packages/contracts-bedrock/scripts/checks/interfaces/main.go create mode 100644 packages/contracts-bedrock/scripts/checks/interfaces/main_test.go diff --git a/packages/contracts-bedrock/justfile b/packages/contracts-bedrock/justfile index 42ab97fcd509f..2b29922ca45ce 100644 --- a/packages/contracts-bedrock/justfile +++ b/packages/contracts-bedrock/justfile @@ -19,10 +19,13 @@ dep-status: prebuild: ./scripts/checks/check-foundry-install.sh -# Builds the contracts. -build: prebuild +# Core forge build command +forge-build: forge build +# Builds the contracts. +build: prebuild lint-fix-no-fail forge-build interfaces-check-no-build + # Builds the go-ffi tool for contract tests. build-go-ffi: cd scripts/go-ffi && go build @@ -137,7 +140,7 @@ snapshots-check: # Checks interface correctness without building. interfaces-check-no-build: - ./scripts/checks/check-interfaces.sh + go run ./scripts/checks/interfaces # Checks that all interfaces are appropriately named and accurately reflect the corresponding # contract that they're meant to represent. We run "clean" before building because leftover @@ -219,5 +222,11 @@ pre-pr-no-build: build-go-ffi build lint gas-snapshot-no-build snapshots-no-buil lint-fix: forge fmt +# Fixes linting errors but doesn't fail if there are syntax errors. Useful for build command +# because the output of forge fmt can sometimes be difficult to understand but if there's a syntax +# error the build will fail anyway and provide more context about what's wrong. +lint-fix-no-fail: + forge fmt || true + # Fixes linting errors and checks that the code is correctly formatted. lint: lint-fix lint-check diff --git a/packages/contracts-bedrock/scripts/checks/interfaces/main.go b/packages/contracts-bedrock/scripts/checks/interfaces/main.go new file mode 100644 index 0000000000000..004f1d5c03cdc --- /dev/null +++ b/packages/contracts-bedrock/scripts/checks/interfaces/main.go @@ -0,0 +1,354 @@ +package main + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "sync" + "sync/atomic" + + "github.com/google/go-cmp/cmp" +) + +var excludeContracts = []string{ + // External dependencies + "IERC20", "IERC721", "IERC721Enumerable", "IERC721Upgradeable", "IERC721Metadata", + "IERC165", "IERC165Upgradeable", "ERC721TokenReceiver", "ERC1155TokenReceiver", + "ERC777TokensRecipient", "Guard", "IProxy", "Vm", "VmSafe", "IMulticall3", + "IERC721TokenReceiver", "IProxyCreationCallback", "IBeacon", + + // EAS + "IEAS", "ISchemaResolver", "ISchemaRegistry", + + // TODO: Interfaces that need to be fixed + "IInitializable", "IPreimageOracle", "ILegacyMintableERC20", "IOptimismMintableERC20", + "IOptimismMintableERC721", "IOptimismSuperchainERC20", "MintableAndBurnable", + "KontrolCheatsBase", "IWETH", "IDelayedWETH", "IL2ToL2CrossDomainMessenger", + "ICrossL2Inbox", "ISystemConfigInterop", "IResolvedDelegateProxy", +} + +type ContractDefinition struct { + ContractKind string `json:"contractKind"` + Name string `json:"name"` +} + +type ASTNode struct { + NodeType string `json:"nodeType"` + Literals []string `json:"literals,omitempty"` + ContractDefinition +} + +type ArtifactAST struct { + Nodes []ASTNode `json:"nodes"` +} + +type Artifact struct { + AST ArtifactAST `json:"ast"` + ABI json.RawMessage `json:"abi"` +} + +func main() { + if err := run(); err != nil { + writeStderr("an error occurred: %v", err) + os.Exit(1) + } +} + +func writeStderr(msg string, args ...any) { + _, _ = fmt.Fprintf(os.Stderr, msg+"\n", args...) +} + +func run() error { + cwd, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get current working directory: %w", err) + } + + artifactsDir := filepath.Join(cwd, "forge-artifacts") + + artifactFiles, err := glob(artifactsDir, ".json") + if err != nil { + return fmt.Errorf("failed to get artifact files: %w", err) + } + + // Remove duplicates from artifactFiles + uniqueArtifacts := make(map[string]string) + for contractName, artifactPath := range artifactFiles { + baseName := strings.Split(contractName, ".")[0] + uniqueArtifacts[baseName] = artifactPath + } + + var hasErr int32 + var outMtx sync.Mutex + fail := func(msg string, args ...any) { + outMtx.Lock() + writeStderr("❌ "+msg, args...) + outMtx.Unlock() + atomic.StoreInt32(&hasErr, 1) + } + + sem := make(chan struct{}, runtime.NumCPU()) + for contractName, artifactPath := range uniqueArtifacts { + contractName := contractName + artifactPath := artifactPath + + sem <- struct{}{} + + go func() { + defer func() { + <-sem + }() + + if err := processArtifact(contractName, artifactPath, artifactsDir, fail); err != nil { + fail("%s: %v", contractName, err) + } + }() + } + + for i := 0; i < cap(sem); i++ { + sem <- struct{}{} + } + + if atomic.LoadInt32(&hasErr) == 1 { + return errors.New("interface check failed, see logs above") + } + + return nil +} + +func processArtifact(contractName, artifactPath, artifactsDir string, fail func(string, ...any)) error { + if isExcluded(contractName) { + return nil + } + + artifact, err := readArtifact(artifactPath) + if err != nil { + return fmt.Errorf("failed to read artifact: %w", err) + } + + contractDef := getContractDefinition(artifact, contractName) + if contractDef == nil { + return nil // Skip processing if contract definition is not found + } + + if contractDef.ContractKind != "interface" { + return nil + } + + if !strings.HasPrefix(contractName, "I") { + fail("%s: Interface does not start with 'I'", contractName) + } + + semver, err := getContractSemver(artifact) + if err != nil { + return err + } + + if semver != "solidity^0.8.0" { + fail("%s: Interface does not have correct compiler version (MUST be exactly solidity ^0.8.0)", contractName) + } + + contractBasename := contractName[1:] + correspondingContractFile := filepath.Join(artifactsDir, contractBasename+".sol", contractBasename+".json") + + if _, err := os.Stat(correspondingContractFile); errors.Is(err, os.ErrNotExist) { + return nil + } + + contractArtifact, err := readArtifact(correspondingContractFile) + if err != nil { + return fmt.Errorf("failed to read corresponding contract artifact: %w", err) + } + + interfaceABI := artifact.ABI + contractABI := contractArtifact.ABI + + normalizedInterfaceABI, err := normalizeABI(interfaceABI) + if err != nil { + return fmt.Errorf("failed to normalize interface ABI: %w", err) + } + + normalizedContractABI, err := normalizeABI(contractABI) + if err != nil { + return fmt.Errorf("failed to normalize contract ABI: %w", err) + } + + match, err := compareABIs(normalizedInterfaceABI, normalizedContractABI) + if err != nil { + return fmt.Errorf("failed to compare ABIs: %w", err) + } + if !match { + fail("%s: Differences found in ABI between interface and actual contract", contractName) + } + + return nil +} + +func readArtifact(path string) (*Artifact, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open artifact file: %w", err) + } + defer file.Close() + + var artifact Artifact + if err := json.NewDecoder(file).Decode(&artifact); err != nil { + return nil, fmt.Errorf("failed to parse artifact file: %w", err) + } + + return &artifact, nil +} + +func getContractDefinition(artifact *Artifact, contractName string) *ContractDefinition { + for _, node := range artifact.AST.Nodes { + if node.NodeType == "ContractDefinition" && node.Name == contractName { + return &node.ContractDefinition + } + } + return nil +} + +func getContractSemver(artifact *Artifact) (string, error) { + for _, node := range artifact.AST.Nodes { + if node.NodeType == "PragmaDirective" { + return strings.Join(node.Literals, ""), nil + } + } + return "", errors.New("semver not found") +} + +func normalizeABI(abi json.RawMessage) (json.RawMessage, error) { + var abiData []map[string]interface{} + if err := json.Unmarshal(abi, &abiData); err != nil { + return nil, err + } + + hasConstructor := false + for i := range abiData { + normalizeABIItem(abiData[i]) + if abiData[i]["type"] == "constructor" { + hasConstructor = true + } + } + + // Add an empty constructor if it doesn't exist + if !hasConstructor { + emptyConstructor := map[string]interface{}{ + "type": "constructor", + "stateMutability": "nonpayable", + "inputs": []interface{}{}, + } + abiData = append(abiData, emptyConstructor) + } + + return json.Marshal(abiData) +} + +func normalizeABIItem(item map[string]interface{}) { + for key, value := range item { + switch v := value.(type) { + case string: + if key == "internalType" { + item[key] = normalizeInternalType(v) + } + case map[string]interface{}: + normalizeABIItem(v) + case []interface{}: + for _, elem := range v { + if elemMap, ok := elem.(map[string]interface{}); ok { + normalizeABIItem(elemMap) + } + } + } + } + + if item["type"] == "function" && item["name"] == "__constructor__" { + item["type"] = "constructor" + delete(item, "name") + delete(item, "outputs") + } +} + +func normalizeInternalType(internalType string) string { + internalType = strings.ReplaceAll(internalType, "contract I", "contract ") + internalType = strings.ReplaceAll(internalType, "enum I", "enum ") + internalType = strings.ReplaceAll(internalType, "struct I", "struct ") + return internalType +} + +func compareABIs(abi1, abi2 json.RawMessage) (bool, error) { + var data1, data2 []map[string]interface{} + + if err := json.Unmarshal(abi1, &data1); err != nil { + return false, fmt.Errorf("error unmarshalling first ABI: %w", err) + } + + if err := json.Unmarshal(abi2, &data2); err != nil { + return false, fmt.Errorf("error unmarshalling second ABI: %w", err) + } + + // Sort the ABI data + sort.Slice(data1, func(i, j int) bool { + return abiItemLess(data1[i], data1[j]) + }) + sort.Slice(data2, func(i, j int) bool { + return abiItemLess(data2[i], data2[j]) + }) + + // Compare using go-cmp + diff := cmp.Diff(data1, data2) + if diff != "" { + return false, nil + } + return true, nil +} + +func abiItemLess(a, b map[string]interface{}) bool { + aType := getString(a, "type") + bType := getString(b, "type") + + if aType != bType { + return aType < bType + } + + aName := getString(a, "name") + bName := getString(b, "name") + return aName < bName +} + +func getString(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func isExcluded(contractName string) bool { + for _, exclude := range excludeContracts { + if exclude == contractName { + return true + } + } + return false +} + +func glob(dir string, ext string) (map[string]string, error) { + out := make(map[string]string) + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if !info.IsDir() && filepath.Ext(path) == ext { + out[strings.TrimSuffix(filepath.Base(path), ext)] = path + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to walk directory: %w", err) + } + return out, nil +} diff --git a/packages/contracts-bedrock/scripts/checks/interfaces/main_test.go b/packages/contracts-bedrock/scripts/checks/interfaces/main_test.go new file mode 100644 index 0000000000000..d1c9237e47223 --- /dev/null +++ b/packages/contracts-bedrock/scripts/checks/interfaces/main_test.go @@ -0,0 +1,295 @@ +package main + +import ( + "encoding/json" + "reflect" + "testing" +) + +func TestGetContractDefinition(t *testing.T) { + artifact := &Artifact{ + AST: ArtifactAST{ + Nodes: []ASTNode{ + {NodeType: "ContractDefinition", ContractDefinition: ContractDefinition{ContractKind: "interface", Name: "ITest"}}, + {NodeType: "ContractDefinition", ContractDefinition: ContractDefinition{ContractKind: "contract", Name: "Test"}}, + {NodeType: "ContractDefinition", ContractDefinition: ContractDefinition{ContractKind: "library", Name: "TestLib"}}, + }, + }, + } + + tests := []struct { + name string + contractName string + want *ContractDefinition + }{ + {"Find interface", "ITest", &ContractDefinition{ContractKind: "interface", Name: "ITest"}}, + {"Find contract", "Test", &ContractDefinition{ContractKind: "contract", Name: "Test"}}, + {"Find library", "TestLib", &ContractDefinition{ContractKind: "library", Name: "TestLib"}}, + {"Not found", "NonExistent", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := getContractDefinition(artifact, tt.contractName) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getContractDefinition() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetContractSemver(t *testing.T) { + tests := []struct { + name string + artifact *Artifact + want string + wantErr bool + }{ + { + name: "Valid semver", + artifact: &Artifact{ + AST: ArtifactAST{ + Nodes: []ASTNode{ + {NodeType: "PragmaDirective", Literals: []string{"solidity", "^", "0.8.0"}}, + }, + }, + }, + want: "solidity^0.8.0", + wantErr: false, + }, + { + name: "Multiple pragmas", + artifact: &Artifact{ + AST: ArtifactAST{ + Nodes: []ASTNode{ + {NodeType: "PragmaDirective", Literals: []string{"solidity", "^", "0.8.0"}}, + {NodeType: "PragmaDirective", Literals: []string{"abicoder", "v2"}}, + }, + }, + }, + want: "solidity^0.8.0", + wantErr: false, + }, + { + name: "No semver", + artifact: &Artifact{ + AST: ArtifactAST{ + Nodes: []ASTNode{ + {NodeType: "ContractDefinition"}, + }, + }, + }, + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getContractSemver(tt.artifact) + if (err != nil) != tt.wantErr { + t.Errorf("getContractSemver() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("getContractSemver() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNormalizeABI(t *testing.T) { + tests := []struct { + name string + abi string + want string + }{ + { + name: "Replace interface types and add constructor", + abi: `[{"inputs":[{"internalType":"contract ITest","name":"test","type":"address"}],"type":"function"}]`, + want: `[{"inputs":[{"internalType":"contract Test","name":"test","type":"address"}],"type":"function"},{"inputs":[],"stateMutability":"nonpayable","type":"constructor"}]`, + }, + { + name: "Convert __constructor__", + abi: `[{"type":"function","name":"__constructor__","inputs":[],"stateMutability":"nonpayable","outputs":[]}]`, + want: `[{"type":"constructor","inputs":[],"stateMutability":"nonpayable"}]`, + }, + { + name: "Keep existing constructor", + abi: `[{"type":"constructor","inputs":[{"name":"param","type":"uint256"}]},{"type":"function","name":"test"}]`, + want: `[{"type":"constructor","inputs":[{"name":"param","type":"uint256"}]},{"type":"function","name":"test"}]`, + }, + { + name: "Replace multiple interface types", + abi: `[{"inputs":[{"internalType":"contract ITest1","name":"test1","type":"address"},{"internalType":"contract ITest2","name":"test2","type":"address"}],"type":"function"}]`, + want: `[{"inputs":[{"internalType":"contract Test1","name":"test1","type":"address"},{"internalType":"contract Test2","name":"test2","type":"address"}],"type":"function"},{"inputs":[],"stateMutability":"nonpayable","type":"constructor"}]`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := normalizeABI(json.RawMessage(tt.abi)) + if err != nil { + t.Errorf("normalizeABI() error = %v", err) + return + } + var gotJSON, wantJSON interface{} + if err := json.Unmarshal(got, &gotJSON); err != nil { + t.Errorf("Error unmarshalling got JSON: %v", err) + return + } + if err := json.Unmarshal([]byte(tt.want), &wantJSON); err != nil { + t.Errorf("Error unmarshalling want JSON: %v", err) + return + } + if !reflect.DeepEqual(gotJSON, wantJSON) { + t.Errorf("normalizeABI() = %v, want %v", string(got), tt.want) + } + }) + } +} + +func TestCompareABIs(t *testing.T) { + tests := []struct { + name string + abi1 string + abi2 string + want bool + }{ + { + name: "Identical ABIs", + abi1: `[{"type":"function","name":"test","inputs":[],"outputs":[]}]`, + abi2: `[{"type":"function","name":"test","inputs":[],"outputs":[]}]`, + want: true, + }, + { + name: "Different ABIs", + abi1: `[{"type":"function","name":"test1","inputs":[],"outputs":[]}]`, + abi2: `[{"type":"function","name":"test2","inputs":[],"outputs":[]}]`, + want: false, + }, + { + name: "Different order, same content", + abi1: `[{"type":"function","name":"test1","inputs":[],"outputs":[]},{"type":"function","name":"test2","inputs":[],"outputs":[]}]`, + abi2: `[{"type":"function","name":"test2","inputs":[],"outputs":[]},{"type":"function","name":"test1","inputs":[],"outputs":[]}]`, + want: true, + }, + { + name: "Different input types", + abi1: `[{"type":"function","name":"test","inputs":[{"type":"uint256"}],"outputs":[]}]`, + abi2: `[{"type":"function","name":"test","inputs":[{"type":"uint128"}],"outputs":[]}]`, + want: false, + }, + { + name: "Different output types", + abi1: `[{"type":"function","name":"test","inputs":[],"outputs":[{"type":"uint256"}]}]`, + abi2: `[{"type":"function","name":"test","inputs":[],"outputs":[{"type":"uint128"}]}]`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := compareABIs(json.RawMessage(tt.abi1), json.RawMessage(tt.abi2)) + if err != nil { + t.Errorf("compareABIs() error = %v", err) + return + } + if got != tt.want { + t.Errorf("compareABIs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsExcluded(t *testing.T) { + tests := []struct { + name string + contractName string + want bool + }{ + {"Excluded contract", "IERC20", true}, + {"Non-excluded contract", "IMyContract", false}, + {"Another excluded contract", "IEAS", true}, + {"Excluded contract (case-sensitive)", "ierc20", false}, + {"Excluded contract with prefix", "IERC20Extension", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isExcluded(tt.contractName); got != tt.want { + t.Errorf("isExcluded() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNormalizeInternalType(t *testing.T) { + tests := []struct { + name string + internalType string + want string + }{ + {"Replace contract I", "contract ITest", "contract Test"}, + {"Replace enum I", "enum IMyEnum", "enum MyEnum"}, + {"Replace struct I", "struct IMyStruct", "struct MyStruct"}, + {"No replacement needed", "uint256", "uint256"}, + {"Don't replace non-prefix I", "contract TestI", "contract TestI"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := normalizeInternalType(tt.internalType); got != tt.want { + t.Errorf("normalizeInternalType() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestABIItemLess(t *testing.T) { + tests := []struct { + name string + a map[string]interface{} + b map[string]interface{} + want bool + }{ + { + name: "Different types", + a: map[string]interface{}{"type": "constructor"}, + b: map[string]interface{}{"type": "function"}, + want: true, + }, + { + name: "Same type, different names", + a: map[string]interface{}{"type": "function", "name": "a"}, + b: map[string]interface{}{"type": "function", "name": "b"}, + want: true, + }, + { + name: "Same type and name", + a: map[string]interface{}{"type": "function", "name": "test"}, + b: map[string]interface{}{"type": "function", "name": "test"}, + want: false, + }, + { + name: "Constructor vs function", + a: map[string]interface{}{"type": "constructor"}, + b: map[string]interface{}{"type": "function", "name": "test"}, + want: true, + }, + { + name: "Event vs function", + a: map[string]interface{}{"type": "event", "name": "TestEvent"}, + b: map[string]interface{}{"type": "function", "name": "test"}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := abiItemLess(tt.a, tt.b); got != tt.want { + t.Errorf("abiItemLess() = %v, want %v", got, tt.want) + } + }) + } +}