-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,121 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
package provider | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
) | ||
|
||
type AccessKeyProvider struct { | ||
cred *Credentials | ||
} | ||
|
||
func NewAccessKeyProvider(accessKeyId, accessKeySecret string) *AccessKeyProvider { | ||
return &AccessKeyProvider{ | ||
cred: &Credentials{ | ||
AccessKeyId: accessKeyId, | ||
AccessKeySecret: accessKeySecret, | ||
}, | ||
} | ||
} | ||
|
||
func (a *AccessKeyProvider) Credentials(ctx context.Context) (*Credentials, error) { | ||
if a.cred.AccessKeyId == "" || a.cred.AccessKeySecret == "" { | ||
return nil, NewNotEnableError(errors.New("AccessKeyId or AccessKeySecret is empty")) | ||
} | ||
|
||
return a.cred, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package provider | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"strings" | ||
"time" | ||
) | ||
|
||
type ChainProvider struct { | ||
providers []CredentialsProvider | ||
|
||
preProvider string | ||
Logger Logger | ||
} | ||
|
||
func NewChainProvider(providers ...CredentialsProvider) *ChainProvider { | ||
if len(providers) == 0 { | ||
return DefaultChainProvider() | ||
} | ||
return &ChainProvider{ | ||
providers: providers, | ||
} | ||
} | ||
|
||
func (c *ChainProvider) Credentials(ctx context.Context) (*Credentials, error) { | ||
var notEnableErrors []string | ||
|
||
for _, p := range c.providers { | ||
cred, err := p.Credentials(ctx) | ||
if err != nil { | ||
if _, ok := err.(*NotEnableError); ok { | ||
c.logger().Debug(fmt.Sprintf("provider %T not enabled will try to next: %s", p, err.Error())) | ||
notEnableErrors = append(notEnableErrors, fmt.Sprintf("provider %T not enabled: %s", p, err.Error())) | ||
continue | ||
} | ||
} | ||
pT := fmt.Sprintf("%T", p) | ||
if err == nil { | ||
if c.preProvider != pT { | ||
c.preProvider = pT | ||
c.logger().Info(fmt.Sprintf("switch to using provider %s", pT)) | ||
} | ||
return cred, nil | ||
} | ||
return cred, fmt.Errorf("get credentials via %s failed: %w", pT, err) | ||
} | ||
return nil, fmt.Errorf("no available credentials providers: %s", strings.Join(notEnableErrors, ", ")) | ||
} | ||
|
||
func (c *ChainProvider) logger() Logger { | ||
if c.Logger != nil { | ||
return c.Logger | ||
} | ||
return defaultLog | ||
} | ||
|
||
func DefaultChainProvider() *ChainProvider { | ||
return NewChainProvider( | ||
NewEnvProvider(EnvProviderOptions{}), | ||
NewOIDCProvider(OIDCProviderOptions{ | ||
RefreshPeriod: time.Minute * 20, | ||
}), | ||
NewEncryptedFileProvider(EncryptedFileProviderOptions{ | ||
RefreshPeriod: time.Minute * 20, | ||
}), | ||
NewECSMetadataProvider(ECSMetadataProviderOptions{ | ||
RefreshPeriod: time.Minute * 20, | ||
}), | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package provider | ||
|
||
import "time" | ||
|
||
type Credentials struct { | ||
AccessKeyId string | ||
AccessKeySecret string | ||
SecurityToken string | ||
Expiration time.Time | ||
} | ||
|
||
func (c Credentials) DeepCopy() Credentials { | ||
return Credentials{ | ||
AccessKeyId: c.AccessKeyId, | ||
AccessKeySecret: c.AccessKeySecret, | ||
SecurityToken: c.SecurityToken, | ||
Expiration: c.Expiration, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
package provider | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strings" | ||
"time" | ||
) | ||
|
||
const ( | ||
defaultExpiryWindow = time.Minute * 5 | ||
defaultECSMetadataServerEndpoint = "http://100.100.100.200" | ||
defaultECSMetadataTokenTTLSeconds = 3600 | ||
defaultClientTimeout = time.Second * 30 | ||
) | ||
|
||
type ECSMetadataProvider struct { | ||
u *Updater | ||
|
||
endpoint string | ||
roleName string | ||
metadataToken string | ||
metadataTokenTTLSeconds int | ||
metadataTokenExp time.Time | ||
|
||
client *http.Client | ||
} | ||
|
||
type ECSMetadataProviderOptions struct { | ||
Endpoint string | ||
Timeout time.Duration | ||
Transport http.RoundTripper | ||
|
||
RoleName string | ||
MetadataTokenTTLSeconds int | ||
|
||
ExpiryWindow time.Duration | ||
RefreshPeriod time.Duration | ||
} | ||
|
||
func NewECSMetadataProvider(opts ECSMetadataProviderOptions) *ECSMetadataProvider { | ||
opts.applyDefaults() | ||
|
||
client := &http.Client{ | ||
Transport: opts.Transport, | ||
Timeout: opts.Timeout, | ||
} | ||
e := &ECSMetadataProvider{ | ||
endpoint: opts.Endpoint, | ||
roleName: opts.RoleName, | ||
metadataTokenTTLSeconds: opts.MetadataTokenTTLSeconds, | ||
client: client, | ||
} | ||
e.u = NewUpdater(e.getCredentials, UpdaterOptions{ | ||
ExpiryWindow: opts.ExpiryWindow, | ||
RefreshPeriod: opts.RefreshPeriod, | ||
}) | ||
e.u.Start(context.TODO()) | ||
|
||
return e | ||
} | ||
|
||
func (e *ECSMetadataProvider) Credentials(ctx context.Context) (*Credentials, error) { | ||
return e.u.Credentials(ctx) | ||
} | ||
|
||
type ecsMetadataStsResponse struct { | ||
AccessKeyId string `json:"AccessKeyId"` | ||
AccessKeySecret string `json:"AccessKeySecret"` | ||
SecurityToken string `json:"SecurityToken"` | ||
Expiration string `json:"Expiration"` | ||
LastUpdated string `json:"LastUpdated"` | ||
Code string `json:"Code"` | ||
} | ||
|
||
type metadataError struct { | ||
code int | ||
message string | ||
} | ||
|
||
func (e *ECSMetadataProvider) getCredentials(ctx context.Context) (*Credentials, error) { | ||
roleName, err := e.getRoleName(ctx) | ||
if err != nil { | ||
if e, ok := err.(*metadataError); ok && e.code == 404 { | ||
return nil, NewNotEnableError(fmt.Errorf("get role name from ecs metadata failed: %w", err)) | ||
} | ||
} | ||
path := fmt.Sprintf("/latest/meta-data/ram/security-credentials/%s", roleName) | ||
data, err := e.getMedataDataWithToken(ctx, http.MethodGet, path) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var obj ecsMetadataStsResponse | ||
if err := json.Unmarshal([]byte(data), &obj); err != nil { | ||
return nil, fmt.Errorf("parse credentials failed: %w", err) | ||
} | ||
if obj.AccessKeySecret == "" { | ||
return nil, fmt.Errorf("parse credentials got unexpected data: %s", | ||
strings.ReplaceAll(data, "\n", " ")) | ||
} | ||
|
||
exp, err := time.Parse("2006-01-02T15:04:05Z", obj.Expiration) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &Credentials{ | ||
AccessKeyId: obj.AccessKeyId, | ||
AccessKeySecret: obj.AccessKeySecret, | ||
SecurityToken: obj.SecurityToken, | ||
Expiration: exp, | ||
}, nil | ||
} | ||
|
||
func (e *ECSMetadataProvider) getRoleName(ctx context.Context) (string, error) { | ||
if e.roleName != "" { | ||
return e.roleName, nil | ||
} | ||
name, err := e.getMedataDataWithToken(ctx, http.MethodGet, "/latest/meta-data/ram/security-credentials/") | ||
if err != nil { | ||
return "", err | ||
} | ||
return strings.TrimSpace(name), nil | ||
} | ||
|
||
func (e *ECSMetadataProvider) getMedataToken(ctx context.Context) (string, error) { | ||
if !e.metadataTokenExp.Before(time.Now()) { | ||
return e.metadataToken, nil | ||
} | ||
|
||
h := http.Header{} | ||
h.Set("X-aliyun-ecs-metadata-token-ttl-seconds", fmt.Sprintf("%d", e.metadataTokenTTLSeconds)) | ||
body, err := e.getMedataData(ctx, http.MethodPut, "/latest/api/token", h) | ||
if err != nil { | ||
return "", fmt.Errorf("get metadata token failed: %w", err) | ||
} | ||
|
||
e.metadataToken = strings.TrimSpace(body) | ||
e.metadataTokenExp = time.Now().Add(time.Duration(float64(e.metadataTokenTTLSeconds)*0.8) * time.Second) | ||
|
||
return body, nil | ||
} | ||
|
||
func (e *ECSMetadataProvider) getMedataDataWithToken(ctx context.Context, method, path string) (string, error) { | ||
token, err := e.getMedataToken(ctx) | ||
if err != nil { | ||
if e, ok := err.(*metadataError); !(ok && e.code == 404) { | ||
return "", err | ||
} | ||
} | ||
h := http.Header{} | ||
if token != "" { | ||
h.Set("X-aliyun-ecs-metadata-token", token) | ||
} | ||
return e.getMedataData(ctx, method, path, h) | ||
} | ||
|
||
func (e *ECSMetadataProvider) getMedataData(ctx context.Context, method, path string, header http.Header) (string, error) { | ||
url := fmt.Sprintf("%s%s", e.endpoint, path) | ||
req, err := http.NewRequest(method, url, nil) | ||
if err != nil { | ||
return "", fmt.Errorf("can not init request with url %s: %w", url, err) | ||
} | ||
req = req.WithContext(ctx) | ||
req.Header.Set("User-Agent", UserAgent) | ||
for k, items := range header { | ||
for _, v := range items { | ||
req.Header.Add(k, v) | ||
} | ||
} | ||
|
||
resp, err := e.client.Do(req) | ||
if err != nil { | ||
return "", fmt.Errorf("request %s failed: %w", url, err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
data, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return "", fmt.Errorf("read body failed when request %s: %w", url, err) | ||
} | ||
if resp.StatusCode != http.StatusOK { | ||
return "", &metadataError{ | ||
code: resp.StatusCode, | ||
message: fmt.Sprintf("status code %d is not 200 when request %s: %s", | ||
resp.StatusCode, url, strings.ReplaceAll(string(data), "\n", " ")), | ||
} | ||
} | ||
return string(data), nil | ||
} | ||
|
||
func (o *ECSMetadataProviderOptions) applyDefaults() { | ||
if o.Timeout <= 0 { | ||
o.Timeout = defaultClientTimeout | ||
} | ||
if o.Transport == nil { | ||
ts := http.DefaultTransport.(*http.Transport).Clone() | ||
o.Transport = ts | ||
} | ||
if o.Endpoint == "" { | ||
o.Endpoint = defaultECSMetadataServerEndpoint | ||
} else { | ||
o.Endpoint = strings.TrimRight(o.Endpoint, "/") | ||
} | ||
if o.MetadataTokenTTLSeconds == 0 { | ||
o.MetadataTokenTTLSeconds = defaultECSMetadataTokenTTLSeconds | ||
} | ||
if o.ExpiryWindow == 0 { | ||
o.ExpiryWindow = defaultExpiryWindow | ||
} | ||
} | ||
|
||
func (e metadataError) Error() string { | ||
return e.message | ||
} |
Oops, something went wrong.