Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func init() {
new(S3Detector),
new(GCSDetector),
new(FileDetector),
new(AzureBlobDetector),
}
}

Expand Down
45 changes: 45 additions & 0 deletions detect_azure_blob.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package getter

import (
"fmt"
"net/url"
"strings"
)

type AzureBlobDetector struct{}

var azureStorageSuffixes = []string{
"blob.core.windows.net",
"blob.core.usgovcloudapi.net",
"blob.core.chinacloudapi.cn",
}

func (d *AzureBlobDetector) Detect(src, pwd string) (string, bool, error) {
if len(src) == 0 {
return "", false, nil
}

for _, s := range azureStorageSuffixes {
if strings.Contains(src, s) {
return d.detectURL(src)
}
}

return "", false, nil
}

func (d *AzureBlobDetector) detectURL(src string) (string, bool, error) {
u, err := url.Parse(src)
if err != nil {
return "", false, err
}

parts := strings.Split(u.Path, "/")
if len(parts) < 2 {
return "", false, fmt.Errorf("path to blob must not be empty")
}

u.Scheme = "https"

return fmt.Sprintf("azureblob::%s", u.String()), true, nil
}
46 changes: 46 additions & 0 deletions detect_azure_blob_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package getter

import (
"testing"
)

func TestAzureBlobDetector(t *testing.T) {
cases := []struct {
Input string
Output string
}{
{
"account.blob.core.windows.net/foo/bar",
"azureblob::https://account.blob.core.windows.net/foo/bar",
},
{
"account.blob.core.usgovcloudapi.net/foo/bar",
"azureblob::https://account.blob.core.usgovcloudapi.net/foo/bar",
},
{
"account.blob.core.chinacloudapi.cn/foo/bar",
"azureblob::https://account.blob.core.chinacloudapi.cn/foo/bar",
},
// Misc tests
{
"account.blob.core.windows.net/foo/bar?version=1234",
"azureblob::https://account.blob.core.windows.net/foo/bar?version=1234",
},
}

pwd := "/pwd"
f := new(AzureBlobDetector)
for i, tc := range cases {
output, ok, err := f.Detect(tc.Input, pwd)
if err != nil {
t.Fatalf("err: %s", err)
}
if !ok {
t.Fatal("not ok")
}

if output != tc.Output {
t.Fatalf("%d: bad: %#v", i, output)
}
}
}
15 changes: 8 additions & 7 deletions get.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,14 @@ func init() {
}

Getters = map[string]Getter{
"file": new(FileGetter),
"git": new(GitGetter),
"gcs": new(GCSGetter),
"hg": new(HgGetter),
"s3": new(S3Getter),
"http": httpGetter,
"https": httpGetter,
"file": new(FileGetter),
"git": new(GitGetter),
"gcs": new(GCSGetter),
"hg": new(HgGetter),
"s3": new(S3Getter),
"azureblob": new(AzureBlobGetter),
"http": httpGetter,
"https": httpGetter,
}
}

Expand Down
267 changes: 267 additions & 0 deletions get_azure_blob.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
package getter

import (
"context"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
)

type AzureBlobGetter struct {
getter

// Timeout sets a deadline which all AzureBlob operations should
// complete within. Zero value means no timeout.
Timeout time.Duration
}

// Get downloads the given URL into the given directory. This always
// assumes that we're updating and gets the latest version that it can.
//
// The directory may already exist (if we're updating). If it is in a
// format that isn't understood, an error should be returned. Get shouldn't
// simply nuke the directory.
func (g *AzureBlobGetter) Get(dst string, url *url.URL) error {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
accountName, baseUrl, containerName, blobPath, _, err := g.parseUrl(url)
if err != nil {
return err
}

// Remove destination if it already exists
_, err = os.Stat(dst)
if err != nil && !os.IsNotExist(err) {
return err
}

if err == nil {
// Remove the destination
if err := os.RemoveAll(dst); err != nil {
return err
}
}

// Create client config
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return err
}
containerUrl, err := url.Parse(fmt.Sprintf("https://%s.%s/%s", accountName, baseUrl, containerName))
if err != nil {
return err
}

client, err := g.newAzureContainerClient(containerUrl, cred)
if err != nil {
return err
}

pager := client.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{
Prefix: &blobPath,
})

for pager.More() {
r, err := pager.NextPage(ctx)
if err != nil {
return err
}
for _, item := range r.ListBlobsFlatSegmentResponse.Segment.BlobItems {
blobFullName := *item.Name

blobPathPart := filepath.Dir(blobFullName)
blobName := filepath.Base(blobFullName)
dstPathPart := strings.TrimPrefix(blobPathPart, blobPath)

dst := strings.Join([]string{dst, dstPathPart}, "/")

blobClient := client.NewBlobClient(blobFullName)
err = os.MkdirAll(dst, os.ModeDir)
if err != nil {
return err
}

f, err := os.Create(dst + "/" + blobName)
if err != nil {
return err
}
defer f.Close()
_, err = blobClient.DownloadFile(ctx, f, nil)
if err != nil {
return err
}
}
}

return nil
}

// GetFile downloads the give URL into the given path. The URL must
// reference a single file. If possible, the Getter should check if
// the remote end contains the same file and no-op this operation.
func (g *AzureBlobGetter) GetFile(dst string, url *url.URL) error {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
accountName, baseUrl, containerName, blobPath, _, err := g.parseUrl(url)
if err != nil {
return err
}

// Remove destination if it already exists
_, err = os.Stat(dst)
if err != nil && !os.IsNotExist(err) {
return err
}

if err == nil {
// Remove the destination
if err := os.RemoveAll(dst); err != nil {
return err
}
}

// Create client config
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return err
}
containerUrl, err := url.Parse(fmt.Sprintf("https://%s.%s/%s", accountName, baseUrl, containerName))
if err != nil {
return err
}

client, err := g.newAzureContainerClient(containerUrl, cred)
if err != nil {
return err
}

blobClient := client.NewBlobClient(blobPath)
err = os.MkdirAll(filepath.Dir(dst), os.ModeDir)
if err != nil {
return err
}

f, err := os.Create(dst)
if err != nil {
return err
}
defer f.Close()
_, err = blobClient.DownloadFile(ctx, f, nil)
if err != nil {
return err
}
return nil
}

// ClientMode returns the mode based on the given URL. This is used to
// allow clients to let the getters decide which mode to use.
func (g *AzureBlobGetter) ClientMode(url *url.URL) (ClientMode, error) {
ctx := g.Context()

if g.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, g.Timeout)
defer cancel()
}

// Parse URL
accountName, baseUrl, containerName, blobPath, _, err := g.parseUrl(url)
if err != nil {
return ClientModeInvalid, err
}
if blobPath == "" {
// Root Path so use DirMode
return ClientModeDir, nil
}

// Create client config
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return ClientModeInvalid, err
}
containerUrl, err := url.Parse(fmt.Sprintf("https://%s.%s/%s", accountName, baseUrl, containerName))
if err != nil {
return ClientModeInvalid, err
}

client, err := g.newAzureContainerClient(containerUrl, cred)
if err != nil {
return ClientModeInvalid, err
}

pager := client.NewListBlobsFlatPager(&container.ListBlobsFlatOptions{
Prefix: &blobPath,
})

for pager.More() {
r, err := pager.NextPage(ctx)
if err != nil {
return ClientModeInvalid, err
}
blobs := r.ListBlobsFlatSegmentResponse.Segment.BlobItems
if len(blobs) == 1 && *blobs[0].Name == blobPath {
return ClientModeFile, nil
} else {
return ClientModeDir, nil
}
}
return ClientModeInvalid, nil

}

func (g *AzureBlobGetter) newAzureContainerClient(url *url.URL, cred azcore.TokenCredential) (client *container.Client, err error) {

client, err = container.NewClient(url.String(), cred, nil)
return
}

func (g *AzureBlobGetter) parseUrl(u *url.URL) (accountName, baseURL, container, blobPath, accessKey string, err error) {
// Expected host style: accountname.blob.core.windows.net.
// The last 3 parts will be different across environments.
hostParts := strings.SplitN(u.Host, ".", 2)
if len(hostParts) != 2 {
err = fmt.Errorf("URL is not a valid Azure Blob URL: %v", hostParts)
return
}

accountName = hostParts[0]
baseURL = hostParts[1]

pathParts := strings.SplitN(strings.TrimPrefix(u.Path, "/"), "/", 2)
if len(pathParts) < 1 {
err = fmt.Errorf("URL is not a valid Azure Blob URL: %v", pathParts)
return
}

container = pathParts[0]
if len(pathParts) > 1 {
blobPath = pathParts[1]
} else {
blobPath = ""
}

accessKey = u.Query().Get("access_key")

return
}
Loading