Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion x/ref/runtime/internal/cloudvm.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ type asyncChooser struct {
func (ac *asyncChooser) ChooseAddresses(protocol string, candidates []net.Addr) ([]net.Addr, error) {
select {
case <-ac.ch:
if cvmErr != nil {
return nil, cvmErr
}
return cvm.ChooseAddresses(protocol, candidates)
case <-ac.ctx.Done():
return nil, ac.ctx.Err()
Expand Down Expand Up @@ -115,7 +118,7 @@ func newCloudVM(ctx context.Context, logger logging.Logger, fl *flags.Virtualize

switch fl.VirtualizationProvider.Get().(flags.VirtualizationProvider) {
case flags.AWS:
if !cloudvm.OnAWS(ctx, time.Second) {
if !cloudvm.OnAWS(ctx, cvm.logger, time.Second) {
if fl.DissallowNativeFallback {
return nil, fmt.Errorf("this process is not running on AWS even though its command line says it is")
}
Expand Down
64 changes: 44 additions & 20 deletions x/ref/runtime/internal/cloudvm/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"sync"
"time"

"v.io/v23/logging"
"v.io/x/ref/lib/stats"
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
)
Expand Down Expand Up @@ -56,65 +57,88 @@ const (
)

var (
onceAWS sync.Once
onAWS bool
onceAWS sync.Once
onAWS bool
onIMDSv2 bool
)

// OnAWS returns true if this process is running on Amazon Web Services.
// If true, the the stats variables AWSAccountIDStatName and GCPRegionStatName
// are set.
func OnAWS(ctx context.Context, timeout time.Duration) bool {
func OnAWS(ctx context.Context, logger logging.Logger, timeout time.Duration) bool {
onceAWS.Do(func() {
onAWS = awsInit(ctx, timeout)
onAWS, onIMDSv2 = awsInit(ctx, logger, timeout)
logger.VI(1).Infof("OnAWS: onAWS: %v, onIMDSv2: %v", onAWS, onIMDSv2)
})
return onAWS
}

// AWSPublicAddrs returns the current public IP of this AWS instance.
// Must be called after OnAWS.
func AWSPublicAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) {
return awsGetAddr(ctx, awsExternalURL(), timeout)
return awsGetAddr(ctx, onIMDSv2, awsExternalURL(), timeout)
}

// AWSPrivateAddrs returns the current private Addrs of this AWS instance.
// Must be called after OnAWS.
func AWSPrivateAddrs(ctx context.Context, timeout time.Duration) ([]net.Addr, error) {
return awsGetAddr(ctx, awsInternalURL(), timeout)
return awsGetAddr(ctx, onIMDSv2, awsInternalURL(), timeout)
}

func awsGet(ctx context.Context, url string, timeout time.Duration) ([]byte, error) {
func awsGet(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]byte, error) {
client := &http.Client{Timeout: timeout}
token, err := awsSetIMDSv2Token(ctx, awsTokenURL(), timeout)
if err != nil {
return nil, err
var token string
var err error
if imdsv2 {
token, err = awsSetIMDSv2Token(ctx, awsTokenURL(), timeout)
if err != nil {
return nil, err
}
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
req.Header.Add("X-aws-ec2-metadata-token", token)
if err != nil {
return nil, err
}
if len(token) > 0 {
req.Header.Add("X-aws-ec2-metadata-token", token)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return nil, err
return nil, fmt.Errorf("HTTP Error: %v %v", url, resp.StatusCode)
}
if server := resp.Header["Server"]; len(server) != 1 || server[0] != "EC2ws" {
return nil, fmt.Errorf("wrong headers")
}
return ioutil.ReadAll(resp.Body)
}

// awsInit returns true if it can access AWS project metadata. It also
// awsInit returns true if it can access AWS project metadata and the version
// of the metadata service it was able to access. It also
// creates two stats variables with the account ID and zone.
func awsInit(ctx context.Context, timeout time.Duration) bool {
body, err := awsGet(ctx, awsIdentityDocURL(), timeout)
func awsInit(ctx context.Context, logger logging.Logger, timeout time.Duration) (bool, bool) {
v2 := false
// Try the v1 service first since it should always work unless v2
// is specifically configured (and hence v1 is disabled), in which
// case the expectation is that it fails fast with a 4xx HTTP error.
body, err := awsGet(ctx, false, awsIdentityDocURL(), timeout)
if err != nil {
return false
logger.VI(1).Infof("failed to access v1 metadata service: %v", err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for adding logs, should help a lot with debugging problems like this in the future :)

// can't access v1, try v2.
body, err = awsGet(ctx, true, awsIdentityDocURL(), timeout)
if err != nil {
logger.VI(1).Infof("failed to access v2 metadata service: %v", err)
return false, false
}
v2 = true
}
doc := map[string]interface{}{}
if err := json.Unmarshal(body, &doc); err != nil {
return false
logger.VI(1).Infof("failed to unmarshal metadata service response: %s: %v", body, err)
return false, false
}
found := 0
for _, v := range []struct {
Expand All @@ -130,11 +154,11 @@ func awsInit(ctx context.Context, timeout time.Duration) bool {
}
}
}
return found == 2
return found == 2, v2
}

func awsGetAddr(ctx context.Context, url string, timeout time.Duration) ([]net.Addr, error) {
body, err := awsGet(ctx, url, timeout)
func awsGetAddr(ctx context.Context, imdsv2 bool, url string, timeout time.Duration) ([]net.Addr, error) {
body, err := awsGet(ctx, imdsv2, url, timeout)
if err != nil {
return nil, err
}
Expand Down
19 changes: 14 additions & 5 deletions x/ref/runtime/internal/cloudvm/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,30 @@ import (
"testing"
"time"

"v.io/x/ref/internal/logger"
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
"v.io/x/ref/runtime/internal/cloudvm/cloudvmtest"
)

func startAWSMetadataServer(t *testing.T) (string, func()) {
host, close := cloudvmtest.StartAWSMetadataServer(t)
func startAWSMetadataServer(t *testing.T, imdsv2Only bool) (string, func()) {
host, close := cloudvmtest.StartAWSMetadataServer(t, imdsv2Only)
SetAWSMetadataHost(host)
return host, close
}

func TestAWS(t *testing.T) {
testAWSIDMSVersion(t, false)
testAWSIDMSVersion(t, true)
}

func testAWSIDMSVersion(t *testing.T, imdsv2Only bool) {
ctx := context.Background()
host, stop := startAWSMetadataServer(t)
host, stop := startAWSMetadataServer(t, imdsv2Only)
defer stop()

if got, want := OnAWS(ctx, time.Second), true; got != want {
logger := logger.NewLogger("test")

if got, want := OnAWS(ctx, logger, time.Second), true; got != want {
t.Errorf("got %v, want %v", got, want)
}

Expand All @@ -45,8 +53,9 @@ func TestAWS(t *testing.T) {
if got, want := pub[0].String(), cloudvmtest.WellKnownPublicIP; got != want {
t.Errorf("got %v, want %v", got, want)
}

externalURL := host + cloudpaths.AWSPublicIPPath + "/noip"
noip, err := awsGetAddr(ctx, externalURL, time.Second)
noip, err := awsGetAddr(ctx, imdsv2Only, externalURL, time.Second)
if err != nil {
t.Fatal(err)
}
Expand Down
24 changes: 17 additions & 7 deletions x/ref/runtime/internal/cloudvm/cloudvmtest/aws_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ import (
"v.io/x/ref/runtime/internal/cloudvm/cloudpaths"
)

func StartAWSMetadataServer(t *testing.T) (string, func()) {
func StartAWSMetadataServer(t *testing.T, imdsv2Only bool) (string, func()) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
var token string
http.HandleFunc(cloudpaths.AWSTokenPath, func(w http.ResponseWriter, req *http.Request) {
mux := &http.ServeMux{}
mux.HandleFunc(cloudpaths.AWSTokenPath, func(w http.ResponseWriter, req *http.Request) {
token = time.Now().String()
w.Header().Add("Server", "EC2ws")
fmt.Fprint(w, token)
Expand All @@ -32,7 +33,13 @@ func StartAWSMetadataServer(t *testing.T) (string, func()) {
return requestToken == token
}

http.HandleFunc(cloudpaths.AWSIdentityDocPath, func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc(cloudpaths.AWSIdentityDocPath, func(w http.ResponseWriter, r *http.Request) {
if imdsv2Only {
if len(r.Header.Get("X-aws-ec2-metadata-token")) == 0 {
w.WriteHeader(http.StatusUnauthorized)
return
}
}
if !validSession(r) {
w.WriteHeader(http.StatusForbidden)
return
Expand All @@ -58,19 +65,22 @@ func StartAWSMetadataServer(t *testing.T) (string, func()) {
fmt.Fprintf(w, format, args...)
}

http.HandleFunc(cloudpaths.AWSPrivateIPPath,
mux.HandleFunc(cloudpaths.AWSPrivateIPPath,
func(w http.ResponseWriter, r *http.Request) {
respond(w, r, WellKnownPrivateIP)
})
http.HandleFunc(cloudpaths.AWSPublicIPPath,
mux.HandleFunc(cloudpaths.AWSPublicIPPath,
func(w http.ResponseWriter, r *http.Request) {
respond(w, r, WellKnownPublicIP)
})
http.HandleFunc(cloudpaths.AWSPublicIPPath+"/noip",
mux.HandleFunc(cloudpaths.AWSPublicIPPath+"/noip",
func(w http.ResponseWriter, r *http.Request) {
respond(w, r, "")
})

go http.Serve(l, nil)
srv := http.Server{
Handler: mux,
}
go srv.Serve(l)
return "http://" + l.Addr().String(), func() { l.Close() }
}
2 changes: 1 addition & 1 deletion x/ref/runtime/internal/cloudvm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func hasAddr(addrs []net.Addr, host string) bool {
}

func TestCloudVMProviders(t *testing.T) {
awsHost, awsClose := cloudvmtest.StartAWSMetadataServer(t)
awsHost, awsClose := cloudvmtest.StartAWSMetadataServer(t, true)
defer awsClose()
cloudvm.SetAWSMetadataHost(awsHost)

Expand Down