Skip to content

Commit

Permalink
Merge pull request #4 from HumairAK/s3_support
Browse files Browse the repository at this point in the history
UPSTREAM: <carry>: add ns scoped s3 support.
  • Loading branch information
openshift-merge-bot[bot] authored Jan 19, 2024
2 parents 888eacc + 58ecf2d commit 4e9b4e6
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 56 deletions.
2 changes: 1 addition & 1 deletion backend/src/v2/component/importer_launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (l *ImportLauncher) Execute(ctx context.Context) (err error) {
}()
// TODO(Bobgy): there's no need to pass any parameters, because pipeline
// and pipeline run context have been created by root DAG driver.
pipeline, err := l.metadataClient.GetPipeline(ctx, l.importerLauncherOptions.PipelineName, l.importerLauncherOptions.RunID, "", "", "")
pipeline, err := l.metadataClient.GetPipeline(ctx, l.importerLauncherOptions.PipelineName, l.importerLauncherOptions.RunID, "", "", "", "")
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions backend/src/v2/component/launcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ func (l *LauncherV2) Execute(ctx context.Context) (err error) {
if err != nil {
return err
}
bucket, err := objectstore.OpenBucket(ctx, l.k8sClient, l.options.Namespace, bucketConfig)
bucketSessionInfo := execution.GetPipeline().GetPipelineBucketSession()
bucket, err := objectstore.OpenBucket(ctx, l.k8sClient, l.options.Namespace, bucketConfig, bucketSessionInfo)
if err != nil {
return err
}
Expand Down Expand Up @@ -539,7 +540,7 @@ func fetchNonDefaultBuckets(
if err != nil {
return nonDefaultBuckets, fmt.Errorf("failed to parse bucketConfig for output artifact %q with uri %q: %w", name, artifact.GetUri(), err)
}
nonDefaultBucket, err := objectstore.OpenBucket(ctx, k8sClient, namespace, nonDefaultBucketConfig)
nonDefaultBucket, err := objectstore.OpenBucket(ctx, k8sClient, namespace, nonDefaultBucketConfig, "")
if err != nil {
return nonDefaultBuckets, fmt.Errorf("failed to open bucket for output artifact %q with uri %q: %w", name, artifact.GetUri(), err)
}
Expand Down
177 changes: 177 additions & 0 deletions backend/src/v2/config/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package config
import (
"context"
"fmt"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
"io/ioutil"
"sigs.k8s.io/yaml"
"strings"

"github.com/golang/glog"
Expand Down Expand Up @@ -82,3 +84,178 @@ func InPodName() (string, error) {
name := string(podName)
return strings.TrimSuffix(name, "\n"), nil
}

// Launcher config key and built-in minio defaults used when resolving the
// bucket session info.
const (
	// configBucketProviders is the key in the launcher config data under
	// which the bucket provider settings are stored.
	configBucketProviders = "providers"

	// Secret name and secret data keys used for the default minio session.
	minioArtifactSecretName   = "mlpipeline-minio-artifact"
	minioArtifactAccessKeyKey = "accesskey"
	minioArtifactSecretKeyKey = "secretkey"

	// defaultMinioEndpointInMultiUserMode is a Kubernetes service DNS name
	// that includes the namespace:
	// https://kubernetes.io/docs/concepts/services-networking/service/#dns
	defaultMinioEndpointInMultiUserMode = "minio-service.kubeflow:9000"
)

// SecretRef identifies a secret by name together with the keys inside it that
// hold the access key and the secret key.
type SecretRef struct {
	SecretName   string `json:"secretName"`
	AccessKeyKey string `json:"accessKeyKey"`
	SecretKeyKey string `json:"secretKeyKey"`
}

// AuthConfig ties a secret reference to one bucket location, described by a
// bucket name and a key prefix.
type AuthConfig struct {
	BucketName string `json:"bucketName"`
	KeyPrefix  string `json:"keyPrefix"`
	// Embedded, so the SecretRef fields are promoted onto AuthConfig; the tag
	// names it "secretRef" in the serialized form.
	*SecretRef `json:"secretRef"`
}

// ProviderConfig carries the connection and credential settings of a single
// object store provider.
type ProviderConfig struct {
	Endpoint string `json:"endpoint"`
	// Secret used unless an entry of AuthConfigs applies.
	DefaultProviderSecretRef *SecretRef `json:"defaultProviderSecretRef"`
	Region                   string     `json:"region"`
	// optional
	DisableSSL bool `json:"disableSSL"`
	// optional, ordered, the auth config for the first matching prefix is used
	AuthConfigs []AuthConfig `json:"authConfigs"`
}

// BucketProviders groups the optional per-provider settings; a nil field means
// that provider has no configuration.
type BucketProviders struct {
	Minio *ProviderConfig `json:"minio"`
	S3    *ProviderConfig `json:"s3"`
	GCS   *ProviderConfig `json:"gcs"`
}

// GetBucketSessionInfo builds the object store session settings for the
// default pipeline root of this launcher config.
//
// The provider is the scheme of the default pipeline root ("minio", "s3" or
// "gs") and is looked up in the optional "providers" config entry:
//
//   - No "providers" entry, or no settings for this provider: minio gets the
//     built-in default minio session, every other provider gets an empty
//     SessionInfo so that authentication is left to the executor environment
//     (e.g. IRSA).
//   - Settings present: endpoint, region and secret reference are taken from
//     them. Only minio has a default endpoint and a default secret reference;
//     the region is always required. An authConfigs entry that applies to the
//     bucket name and prefix of the pipeline root overrides
//     defaultProviderSecretRef.
//
// An error is returned when the pipeline root or the providers config cannot
// be parsed, when providers are configured but the scheme is not supported, or
// when a required setting is missing.
func (c *Config) GetBucketSessionInfo() (objectstore.SessionInfo, error) {
	bucketConfig, err := objectstore.ParseBucketConfig(c.DefaultPipelineRoot())
	if err != nil {
		return objectstore.SessionInfo{}, err
	}
	provider := strings.TrimSuffix(bucketConfig.Scheme, "://")

	bucketProviders, err := c.getBucketProviders()
	if err != nil {
		return objectstore.SessionInfo{}, err
	}

	// Case 1: no "providers" field in kfp-launcher.
	if bucketProviders == nil {
		return sessionInfoWithoutProviderConfig(provider), nil
	}

	var providerConfig *ProviderConfig
	switch provider {
	case "minio":
		providerConfig = bucketProviders.Minio
	case "s3":
		providerConfig = bucketProviders.S3
	case "gs":
		// The "gs" scheme is configured under the "gcs" provider entry.
		providerConfig = bucketProviders.GCS
	default:
		return objectstore.SessionInfo{}, fmt.Errorf("Encountered unsupported provider in BucketProviders %s", provider)
	}

	// Case 2: "providers" is present but has no settings for this provider.
	if providerConfig == nil {
		return sessionInfoWithoutProviderConfig(provider), nil
	}

	// Case 3: settings for this provider are specified.
	endpoint := providerConfig.Endpoint
	if endpoint == "" {
		if provider != "minio" {
			return objectstore.SessionInfo{}, fmt.Errorf("Invalid provider config, %s.endpoint is required for this storage provider", provider)
		}
		endpoint = objectstore.MinioDefaultEndpoint()
	}

	// Start from the provider-wide secret; only minio has a built-in default.
	secretRef := providerConfig.DefaultProviderSecretRef
	if secretRef == nil {
		if provider != "minio" {
			return objectstore.SessionInfo{}, fmt.Errorf("Invalid provider config, %s.defaultProviderSecretRef is required for this storage provider", provider)
		}
		secretRef = &SecretRef{
			SecretName:   minioArtifactSecretName,
			SecretKeyKey: minioArtifactSecretKeyKey,
			AccessKeyKey: minioArtifactAccessKeyKey,
		}
	}

	region := providerConfig.Region
	if region == "" {
		return objectstore.SessionInfo{}, fmt.Errorf("Invalid provider config, missing provider region")
	}

	// A secret configured for this specific bucket/prefix takes precedence
	// over the provider-wide default, but it must be fully specified.
	authConfig := getBucketAuthByPrefix(providerConfig.AuthConfigs, bucketConfig.BucketName, bucketConfig.Prefix)
	if authConfig != nil {
		ref := authConfig.SecretRef
		if ref == nil || ref.SecretKeyKey == "" || ref.AccessKeyKey == "" || ref.SecretName == "" {
			return objectstore.SessionInfo{}, fmt.Errorf("Invalid provider config, %s.AuthConfigs[].secretRef is missing or invalid", provider)
		}
		secretRef = ref
	}

	return objectstore.SessionInfo{
		Region:   region,
		Endpoint: endpoint,
		// DisableSSL is optional and is false when not provided.
		DisableSSL:   providerConfig.DisableSSL,
		SecretName:   secretRef.SecretName,
		AccessKeyKey: secretRef.AccessKeyKey,
		SecretKeyKey: secretRef.SecretKeyKey,
	}, nil
}

// sessionInfoWithoutProviderConfig returns the session to use when the
// launcher config holds no settings for provider: the default minio session
// for "minio", otherwise an empty SessionInfo.
func sessionInfoWithoutProviderConfig(provider string) objectstore.SessionInfo {
	if provider == "minio" {
		return getDefaultMinioSessionInfo()
	}
	// Without any provider settings, rely on the executor environment
	// (e.g. IRSA) for authenticating with the provider.
	return objectstore.SessionInfo{}
}

// getDefaultMinioSessionInfo returns the built-in minio session: region
// "minio", the default minio endpoint reported by objectstore, SSL disabled,
// and the default minio artifact secret name and keys.
func getDefaultMinioSessionInfo() (sessionInfo objectstore.SessionInfo) {
	sessionInfo.Region = "minio"
	sessionInfo.Endpoint = objectstore.MinioDefaultEndpoint()
	sessionInfo.DisableSSL = true
	sessionInfo.SecretName = minioArtifactSecretName
	sessionInfo.AccessKeyKey = minioArtifactAccessKeyKey
	sessionInfo.SecretKeyKey = minioArtifactSecretKeyKey
	return sessionInfo
}

// getBucketProviders parses the "providers" entry of the launcher config.
//
// It returns (nil, nil) when the receiver is nil or the entry is missing or
// empty, and an error when the entry cannot be unmarshalled.
func (c *Config) getBucketProviders() (*BucketProviders, error) {
	if c == nil {
		return nil, nil
	}
	rawProviders := c.data[configBucketProviders]
	if rawProviders == "" {
		return nil, nil
	}

	parsed := new(BucketProviders)
	if err := yaml.Unmarshal([]byte(rawProviders), parsed); err != nil {
		return nil, fmt.Errorf("failed to unmarshall kfp bucket providers, ensure that providers config is well formed: %w", err)
	}
	return parsed, nil
}

// getBucketAuthByPrefix returns the first entry of authConfigs that applies to
// the given bucket location, or nil when none does.
//
// An entry applies when its BucketName equals bucketName and its KeyPrefix is
// a leading part of prefix (an exact match included). Entries are checked in
// order, so more specific prefixes should be listed before broader ones; an
// entry with an empty KeyPrefix applies to every key of its bucket.
//
// The returned pointer refers to the element inside authConfigs.
func getBucketAuthByPrefix(authConfigs []AuthConfig, bucketName, prefix string) *AuthConfig {
	for i := range authConfigs {
		candidate := &authConfigs[i]
		if candidate.BucketName == bucketName && strings.HasPrefix(prefix, candidate.KeyPrefix) {
			return candidate
		}
	}
	return nil
}
25 changes: 18 additions & 7 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/kubeflow/pipelines/backend/src/v2/objectstore"
"path"
"strconv"
"strings"
Expand Down Expand Up @@ -132,6 +133,7 @@ func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (executio
}
// TODO(v2): in pipeline spec, rename GCS output directory to pipeline root.
pipelineRoot := opts.RuntimeConfig.GetGcsOutputDirectory()
pipelineBucketSessionInfo := objectstore.SessionInfo{}
if pipelineRoot != "" {
glog.Infof("PipelineRoot=%q", pipelineRoot)
} else {
Expand All @@ -149,9 +151,18 @@ func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (executio
}
pipelineRoot = cfg.DefaultPipelineRoot()
glog.Infof("PipelineRoot=%q from default config", pipelineRoot)
pipelineBucketSessionInfo, err = cfg.GetBucketSessionInfo()
if err != nil {
return nil, err
}
}
bucketSessionInfo, err := json.Marshal(pipelineBucketSessionInfo)
if err != nil {
return nil, err
}
bucketSessionInfoEntry := string(bucketSessionInfo)
// TODO(Bobgy): fill in run resource.
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, opts.Namespace, "run-resource", pipelineRoot)
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, opts.Namespace, "run-resource", pipelineRoot, bucketSessionInfoEntry)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -230,7 +241,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl
}
// TODO(Bobgy): there's no need to pass any parameters, because pipeline
// and pipeline run context have been created by root DAG driver.
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "")
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "", "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -535,7 +546,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
}
// TODO(Bobgy): there's no need to pass any parameters, because pipeline
// and pipeline run context have been created by root DAG driver.
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "")
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "", "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1206,7 +1217,7 @@ func createPVC(
// Create execution regardless the operation succeeds or not
defer func() {
if createdExecution == nil {
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "")
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "", "")
if err != nil {
return
}
Expand Down Expand Up @@ -1286,7 +1297,7 @@ func createPVC(
ecfg.CachedMLMDExecutionID = cachedMLMDExecutionID
ecfg.FingerPrint = fingerPrint

pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "")
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "", "")
if err != nil {
return "", createdExecution, pb.Execution_FAILED, fmt.Errorf("error getting pipeline from MLMD: %w", err)
}
Expand Down Expand Up @@ -1376,7 +1387,7 @@ func deletePVC(
// Create execution regardless the operation succeeds or not
defer func() {
if createdExecution == nil {
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "")
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "", "")
if err != nil {
return
}
Expand Down Expand Up @@ -1406,7 +1417,7 @@ func deletePVC(
ecfg.CachedMLMDExecutionID = cachedMLMDExecutionID
ecfg.FingerPrint = fingerPrint

pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "")
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "", "")
if err != nil {
return createdExecution, pb.Execution_FAILED, fmt.Errorf("error getting pipeline from MLMD: %w", err)
}
Expand Down
50 changes: 32 additions & 18 deletions backend/src/v2/metadata/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ var (
)

type ClientInterface interface {
GetPipeline(ctx context.Context, pipelineName, runID, namespace, runResource, pipelineRoot string) (*Pipeline, error)
GetPipeline(ctx context.Context, pipelineName, runID, namespace, runResource, pipelineRoot, bucketSessionInfo string) (*Pipeline, error)
GetDAG(ctx context.Context, executionID int64) (*DAG, error)
PublishExecution(ctx context.Context, execution *Execution, outputParameters map[string]*structpb.Value, outputArtifacts []*OutputArtifact, state pb.Execution_State) error
CreateExecution(ctx context.Context, pipeline *Pipeline, config *ExecutionConfig) (*Execution, error)
Expand Down Expand Up @@ -200,6 +200,18 @@ func (p *Pipeline) GetCtxID() int64 {
return p.pipelineCtx.GetId()
}

// GetPipelineBucketSession returns the string stored under the
// keySessionInfoDetails custom property of the pipeline run context. It
// returns "" when the receiver is nil or the property is not set.
func (p *Pipeline) GetPipelineBucketSession() string {
	if p == nil {
		return ""
	}
	if sessionValue, found := p.pipelineRunCtx.GetCustomProperties()[keySessionInfoDetails]; found {
		return sessionValue.GetStringValue()
	}
	return ""
}

func (p *Pipeline) GetPipelineRoot() string {
if p == nil {
return ""
Expand Down Expand Up @@ -262,7 +274,7 @@ func (e *Execution) FingerPrint() string {

// GetPipeline returns the current pipeline represented by the specified
// pipeline name and run ID.
func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace, runResource, pipelineRoot string) (*Pipeline, error) {
func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace, runResource, pipelineRoot, bucketSessionInfo string) (*Pipeline, error) {
pipelineContext, err := c.getOrInsertContext(ctx, pipelineName, pipelineContextType, nil)
if err != nil {
return nil, err
Expand All @@ -272,7 +284,8 @@ func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace
keyNamespace: stringValue(namespace),
keyResourceName: stringValue(runResource),
// pipeline root of this run
keyPipelineRoot: stringValue(strings.TrimRight(pipelineRoot, "/") + "/" + path.Join(pipelineName, runID)),
keyPipelineRoot: stringValue(strings.TrimRight(pipelineRoot, "/") + "/" + path.Join(pipelineName, runID)),
keySessionInfoDetails: stringValue(bucketSessionInfo),
}
runContext, err := c.getOrInsertContext(ctx, runID, pipelineRunContextType, metadata)
glog.Infof("Pipeline Run Context: %+v", runContext)
Expand Down Expand Up @@ -464,21 +477,22 @@ func (c *Client) PublishExecution(ctx context.Context, execution *Execution, out

// metadata keys
const (
keyDisplayName = "display_name"
keyTaskName = "task_name"
keyImage = "image"
keyPodName = "pod_name"
keyPodUID = "pod_uid"
keyNamespace = "namespace"
keyResourceName = "resource_name"
keyPipelineRoot = "pipeline_root"
keyCacheFingerPrint = "cache_fingerprint"
keyCachedExecutionID = "cached_execution_id"
keyInputs = "inputs"
keyOutputs = "outputs"
keyParentDagID = "parent_dag_id" // Parent DAG Execution ID.
keyIterationIndex = "iteration_index"
keyIterationCount = "iteration_count"
keyDisplayName = "display_name"
keyTaskName = "task_name"
keyImage = "image"
keyPodName = "pod_name"
keyPodUID = "pod_uid"
keyNamespace = "namespace"
keyResourceName = "resource_name"
keyPipelineRoot = "pipeline_root"
keySessionInfoDetails = "bucket_session_info"
keyCacheFingerPrint = "cache_fingerprint"
keyCachedExecutionID = "cached_execution_id"
keyInputs = "inputs"
keyOutputs = "outputs"
keyParentDagID = "parent_dag_id" // Parent DAG Execution ID.
keyIterationIndex = "iteration_index"
keyIterationCount = "iteration_count"
)

// CreateExecution creates a new MLMD execution under the specified Pipeline.
Expand Down
Loading

0 comments on commit 4e9b4e6

Please sign in to comment.