Skip to content

Commit

Permalink
eth: Add authenticated geth rpc capabilities.
Browse files Browse the repository at this point in the history
Add ability for client and server to connect over authenticated
websocket to a geth full node.
  • Loading branch information
JoeGruffins committed Nov 25, 2022
1 parent 9ef398e commit f849ce0
Show file tree
Hide file tree
Showing 23 changed files with 790 additions and 313 deletions.
68 changes: 48 additions & 20 deletions client/asset/eth/eth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"
Expand Down Expand Up @@ -73,7 +74,7 @@ const (
walletTypeGeth = "geth"
walletTypeRPC = "rpc"

providersKey = "providers"
providersKey = "providersv1"

// confCheckTimeout is the amount of time allowed to check for
// confirmations. Testing on testnet has shown spikes up to 2.5
Expand Down Expand Up @@ -101,13 +102,13 @@ var (
}
RPCOpts = []*asset.ConfigOption{
{
Key: providersKey,
DisplayName: "Provider",
Description: "Specify one or more providers. For infrastructure " +
Key: providersKey,
RepeatableDisplayName: []string{"Provider", "jwt secret"},
RepeatableDescription: []string{"Specify one or more providers. For infrastructure " +
"providers, use an https address. Only url-based authentication " +
"is supported. For a local node, use the filepath to an IPC file.",
Repeatable: providerDelimiter,
Required: true,
"Specify a jwt secret if communication with a geth full node over ws."},
Required: true,
},
}
// WalletInfo defines some general information about a Ethereum wallet.
Expand Down Expand Up @@ -515,6 +516,31 @@ func CreateWallet(cfg *asset.CreateWalletParams) error {
return createWallet(cfg, false)
}

// endpointsFromSettings parses endpoints from the setting map. Endpoints are
// stored as and array of and array of strings.
func endpointsFromSettings(settings map[string]string) ([]endpoint, error) {
providerDef := settings[providersKey]
if len(providerDef) == 0 {
return nil, errors.New("no providers specified")
}
var values [][]string
err := json.Unmarshal([]byte(providerDef), &values)
if err != nil {
return nil, err
}
endpoints := make([]endpoint, len(values))
for i, v := range values {
switch len(v) {
case 2:
endpoints[i].jwt = v[1]
fallthrough
case 1:
endpoints[i].addr = v[0]
}
}
return endpoints, nil
}

func createWallet(createWalletParams *asset.CreateWalletParams, skipConnect bool) error {
switch createWalletParams.Type {
case walletTypeGeth:
Expand Down Expand Up @@ -557,12 +583,10 @@ func createWallet(createWalletParams *asset.CreateWalletParams, skipConnect bool
case walletTypeRPC:

// Check that we can connect to all endpoints.
providerDef := createWalletParams.Settings[providersKey]
if len(providerDef) == 0 {
return errors.New("no providers specified")
endpoints, err := endpointsFromSettings(createWalletParams.Settings)
if err != nil {
return fmt.Errorf("unable to read endpoints: %v", err)
}
endpoints := strings.Split(providerDef, providerDelimiter)
n := len(endpoints)

// TODO: This procedure may actually work for walletTypeGeth too.
ks := keystore.NewKeyStore(filepath.Join(walletDir, "keystore"), keystore.LightScryptN, keystore.LightScryptP)
Expand All @@ -576,14 +600,15 @@ func createWallet(createWalletParams *asset.CreateWalletParams, skipConnect bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

var unknownEndpoints []string
var unknownEndpoints []endpoint

for _, endpoint := range endpoints {
known, compliant := providerIsCompliant(endpoint)
for _, ep := range endpoints {
addr := ep.addr
known, compliant := providerIsCompliant(addr)
if known && !compliant {
return fmt.Errorf("provider %q is known to have an insufficient API for DEX", endpoint)
return fmt.Errorf("provider %q is known to have an insufficient API for DEX", addr)
} else if !known {
unknownEndpoints = append(unknownEndpoints, endpoint)
unknownEndpoints = append(unknownEndpoints, ep)
}
}

Expand All @@ -597,8 +622,8 @@ func createWallet(createWalletParams *asset.CreateWalletParams, skipConnect bool
p.ec.Close()
}
}()
if len(providers) != n {
return fmt.Errorf("Could not connect to all providers")
if len(providers) != len(endpoints) {
return errors.New("could not connect to all providers")
}
if err := checkProvidersCompliance(ctx, walletDir, providers, createWalletParams.Logger); err != nil {
return err
Expand Down Expand Up @@ -727,7 +752,10 @@ func (w *ETHWallet) Connect(ctx context.Context) (_ *sync.WaitGroup, err error)
// }
return nil, asset.ErrWalletTypeDisabled
case walletTypeRPC:
endpoints := strings.Split(w.settings[providersKey], " ")
endpoints, err := endpointsFromSettings(w.settings)
if err != nil {
return nil, fmt.Errorf("unable to read endpoints: %v", err)
}
ethCfg, err := ethChainConfig(w.net)
if err != nil {
return nil, err
Expand All @@ -737,7 +765,7 @@ func (w *ETHWallet) Connect(ctx context.Context) (_ *sync.WaitGroup, err error)
// Point to a harness node on simnet, if not specified.
if w.net == dex.Simnet && len(endpoints) == 0 {
u, _ := user.Current()
endpoints = append(endpoints, filepath.Join(u.HomeDir, "dextest", "eth", "beta", "node", "geth.ipc"))
endpoints = append(endpoints, endpoint{addr: filepath.Join(u.HomeDir, "dextest", "eth", "beta", "node", "geth.ipc")})
}

cl, err = newMultiRPCClient(w.dir, endpoints, w.log.SubLogger("RPC"), chainConfig, big.NewInt(chainIDs[w.net]), w.net)
Expand Down
6 changes: 3 additions & 3 deletions client/asset/eth/eth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2936,7 +2936,7 @@ func TestDriverOpen(t *testing.T) {
logger := dex.StdOutLogger("ETHTEST", dex.LevelOff)
tmpDir := t.TempDir()

settings := map[string]string{providersKey: "a.ipc"}
settings := map[string]string{providersKey: `[["a.ipc",""]]`}
err := createWallet(&asset.CreateWalletParams{
Type: walletTypeRPC,
Seed: encode.RandomBytes(32),
Expand Down Expand Up @@ -2987,7 +2987,7 @@ func TestDriverExists(t *testing.T) {
drv := &Driver{}
tmpDir := t.TempDir()

settings := map[string]string{providersKey: "a.ipc"}
settings := map[string]string{providersKey: `[["a.ipc",""]]`}

// no wallet
exists, err := drv.Exists(walletTypeRPC, tmpDir, settings, dex.Simnet)
Expand Down Expand Up @@ -4554,7 +4554,7 @@ func testMaxSwapRedeemLots(t *testing.T, assetID uint32) {
logger := dex.StdOutLogger("ETHTEST", dex.LevelOff)
tmpDir := t.TempDir()

settings := map[string]string{providersKey: "a.ipc"}
settings := map[string]string{providersKey: `[["a.ipc",""]]`}
err := createWallet(&asset.CreateWalletParams{
Type: walletTypeRPC,
Seed: encode.RandomBytes(32),
Expand Down
62 changes: 41 additions & 21 deletions client/asset/eth/multirpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"fmt"
"math/big"
"math/rand"
"net/http"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -206,7 +207,7 @@ func (p *provider) subscribeHeaders(ctx context.Context, sub ethereum.Subscripti
return sub, nil
}
if time.Since(lastWarning) > 5*time.Minute {
log.Warnf("can't resubscribe to %q headers: %v", err)
log.Warnf("can't resubscribe to %q headers: %v", p.host, err)
}
select {
case <-time.After(time.Second * 30):
Expand Down Expand Up @@ -254,7 +255,7 @@ func (p *provider) subscribeHeaders(ctx context.Context, sub ethereum.Subscripti
return
}
log.Errorf("%q header subscription error: %v", p.host, err)
log.Info("Attempting to resubscribe to %q block headers", p.host)
log.Infof("Attempting to resubscribe to %q block headers", p.host)
sub, err = newSub()
if err != nil { // context cancelled
return
Expand Down Expand Up @@ -282,6 +283,11 @@ type receiptRecord struct {
confirmed bool
}

type endpoint struct {
addr string
jwt string
}

// multiRPCClient is an ethFetcher backed by one or more public RPC providers.
type multiRPCClient struct {
cfg *params.ChainConfig
Expand All @@ -290,7 +296,7 @@ type multiRPCClient struct {
chainID *big.Int

providerMtx sync.Mutex
endpoints []string
endpoints []endpoint
providers []*provider

lastNonce struct {
Expand All @@ -316,7 +322,7 @@ type multiRPCClient struct {

var _ ethFetcher = (*multiRPCClient)(nil)

func newMultiRPCClient(dir string, endpoints []string, log dex.Logger, cfg *params.ChainConfig, chainID *big.Int, net dex.Network) (*multiRPCClient, error) {
func newMultiRPCClient(dir string, endpoints []endpoint, log dex.Logger, cfg *params.ChainConfig, chainID *big.Int, net dex.Network) (*multiRPCClient, error) {
walletDir := getWalletDir(dir, net)
creds, err := pathCredentials(filepath.Join(walletDir, "keystore"))
if err != nil {
Expand All @@ -340,7 +346,7 @@ func newMultiRPCClient(dir string, endpoints []string, log dex.Logger, cfg *para
// list of providers that were successfully connected. It is not an error for a
// connection to fail. The caller can infer failed connections from the length
// and contents of the returned provider list.
func connectProviders(ctx context.Context, endpoints []string, log dex.Logger, chainID *big.Int) ([]*provider, error) {
func connectProviders(ctx context.Context, endpoints []endpoint, log dex.Logger, chainID *big.Int) ([]*provider, error) {
providers := make([]*provider, 0, len(endpoints))
var success bool

Expand All @@ -352,7 +358,7 @@ func connectProviders(ctx context.Context, endpoints []string, log dex.Logger, c
}
}()

for _, endpoint := range endpoints {
for _, ep := range endpoints {
// First try to get a websocket connection. Websockets have a header
// feed, so are much preferred to http connections. So much so, that
// we'll do some path inspection here and make an attempt to find a
Expand All @@ -362,10 +368,14 @@ func connectProviders(ctx context.Context, endpoints []string, log dex.Logger, c
var sub ethereum.Subscription
var h chan *types.Header
host := providerIPC
if !strings.HasSuffix(endpoint, ".ipc") {
wsURL, err := url.Parse(endpoint)
addr := ep.addr
if strings.HasSuffix(addr, ".ipc") {
// Clean file path.
addr = dex.CleanAndExpandPath(addr)
} else {
wsURL, err := url.Parse(addr)
if err != nil {
return nil, fmt.Errorf("Failed to parse url %q", endpoint)
return nil, fmt.Errorf("Failed to parse url %q", addr)
}
host = wsURL.Host
ogScheme := wsURL.Scheme
Expand All @@ -376,7 +386,7 @@ func connectProviders(ctx context.Context, endpoints []string, log dex.Logger, c
wsURL.Scheme = "ws"
case "ws", "wss":
default:
return nil, fmt.Errorf("unknown scheme for endpoint %q: %q", endpoint, wsURL.Scheme)
return nil, fmt.Errorf("unknown scheme for endpoint %q: %q", addr, wsURL.Scheme)
}
replaced := ogScheme != wsURL.Scheme

Expand All @@ -392,7 +402,18 @@ func connectProviders(ctx context.Context, endpoints []string, log dex.Logger, c
host = providerRivetCloud
}

rpcClient, err = rpc.DialWebsocket(ctx, wsURL.String(), "")
if ep.jwt == "" {
rpcClient, err = rpc.DialWebsocket(ctx, wsURL.String(), "")
} else {
// Geth clients should always be able to get a
// websocket connection, making http unnecessary.
var authFn func(h http.Header) error
authFn, err = dexeth.JWTHTTPAuthFn(ep.jwt)
if err != nil {
return nil, fmt.Errorf("unable to create auth function: %v", err)
}
rpcClient, err = rpc.DialOptions(ctx, wsURL.String(), rpc.WithHTTPAuth(authFn))
}
if err == nil {
ec = ethclient.NewClient(rpcClient)
h = make(chan *types.Header, 8)
Expand All @@ -410,17 +431,17 @@ func connectProviders(ctx context.Context, endpoints []string, log dex.Logger, c
if replaced {
log.Debugf("couldn't get a websocket connection for %q (original scheme: %q) (OK)", wsURL, ogScheme)
} else {
log.Errorf("failed to get websocket connection to %q. attempting http(s) fallback: error = %v", endpoint, err)
log.Errorf("failed to get websocket connection to %q. attempting http(s) fallback: error = %v", addr, err)
}
}
}
// Weren't able to get a websocket connection. Try HTTP now. Dial does
// path discrimination, so I won't even try to validate the protocol.
if ec == nil {
var err error
rpcClient, err = rpc.Dial(endpoint)
rpcClient, err = rpc.Dial(addr)
if err != nil {
log.Errorf("error creating http client for %q: %v", endpoint, err)
log.Errorf("error creating http client for %q: %v", addr, err)
continue
}
ec = ethclient.NewClient(rpcClient)
Expand All @@ -431,20 +452,20 @@ func connectProviders(ctx context.Context, endpoints []string, log dex.Logger, c
if err != nil {
// If we can't get a header, don't use this provider.
ec.Close()
log.Errorf("Failed to get chain ID from %q: %v", endpoint, err)
log.Errorf("Failed to get chain ID from %q: %v", addr, err)
continue
}
if chainID.Cmp(reportedChainID) != 0 {
ec.Close()
log.Errorf("%q reported wrong chain ID. expected %d, got %d", endpoint, chainID, reportedChainID)
log.Errorf("%q reported wrong chain ID. expected %d, got %d", addr, chainID, reportedChainID)
continue
}

hdr, err := ec.HeaderByNumber(ctx, nil /* latest */)
if err != nil {
// If we can't get a header, don't use this provider.
ec.Close()
log.Errorf("Failed to get header from %q: %v", endpoint, err)
log.Errorf("Failed to get header from %q: %v", addr, err)
continue
}

Expand Down Expand Up @@ -546,11 +567,10 @@ func (m *multiRPCClient) voidUnusedNonce() {
}

func (m *multiRPCClient) reconfigure(ctx context.Context, settings map[string]string) error {
providerDef := settings[providersKey]
if len(providerDef) == 0 {
return errors.New("no providers specified")
endpoints, err := endpointsFromSettings(settings)
if err != nil {
return fmt.Errorf("unable to read endpoints: %v", err)
}
endpoints := strings.Split(providerDef, " ")
providers, err := connectProviders(ctx, endpoints, m.log, m.chainID)
if err != nil {
return err
Expand Down
Loading

0 comments on commit f849ce0

Please sign in to comment.