From bef05afb5af2c3ecd4538204f003c1d972c6f935 Mon Sep 17 00:00:00 2001 From: Humair Khan Date: Wed, 13 Mar 2024 10:53:15 -0400 Subject: [PATCH] UPSTREAM: : fix: handle cached input in objstore This change resolves a bug that occurs when a pipeline step needs to retrieve an artifact input that is cached from a different run; the fix re-uses the default bucket's configuration for such artifacts. Signed-off-by: Humair Khan --- backend/src/v2/component/launcher_v2.go | 23 +++++++++-- backend/src/v2/component/launcher_v2_test.go | 2 +- backend/src/v2/config/env.go | 2 +- backend/src/v2/objectstore/object_store.go | 38 ++++++++++++++----- .../src/v2/objectstore/object_store_test.go | 2 +- 5 files changed, 50 insertions(+), 17 deletions(-) diff --git a/backend/src/v2/component/launcher_v2.go b/backend/src/v2/component/launcher_v2.go index 78749471819..9a3b3824488 100644 --- a/backend/src/v2/component/launcher_v2.go +++ b/backend/src/v2/component/launcher_v2.go @@ -156,12 +156,16 @@ func (l *LauncherV2) Execute(ctx context.Context) (err error) { return err } fingerPrint := execution.FingerPrint() - bucketConfig, err := objectstore.ParseBucketConfig(execution.GetPipeline().GetPipelineRoot()) + bucketSessionInfo, err := objectstore.GetSessionInfoFromString(execution.GetPipeline().GetPipelineBucketSession()) if err != nil { return err } - bucketSessionInfo := execution.GetPipeline().GetPipelineBucketSession() - bucket, err := objectstore.OpenBucket(ctx, l.k8sClient, l.options.Namespace, bucketConfig, bucketSessionInfo) + pipelineRoot := execution.GetPipeline().GetPipelineRoot() + bucketConfig, err := objectstore.ParseBucketConfig(pipelineRoot, bucketSessionInfo) + if err != nil { + return err + } + bucket, err := objectstore.OpenBucket(ctx, l.k8sClient, l.options.Namespace, bucketConfig) if err != nil { return err } @@ -535,12 +539,23 @@ func fetchNonDefaultBuckets( } // TODO: Support multiple artifacts someday, probably through the v2 engine. 
artifact := artifactList.Artifacts[0] + // The artifact does not belong under the s3 path for this run + // Reasons: + // 1. Artifact is cached from a different run, so it may still be in the default bucket, but under a different run id subpath + // 2. Artifact is imported from a different bucket, or obj store + // a. If imported, artifact bucket can still be specified in kfp-launcher config (not implemented) + // b. If imported, artifact bucket can not be in kfp-launcher config, in this case, return no session and rely on env for aws config if !strings.HasPrefix(artifact.Uri, defaultBucketConfig.PrefixedBucket()) { nonDefaultBucketConfig, err := objectstore.ParseBucketConfigForArtifactURI(artifact.Uri) if err != nil { return nonDefaultBuckets, fmt.Errorf("failed to parse bucketConfig for output artifact %q with uri %q: %w", name, artifact.GetUri(), err) } - nonDefaultBucket, err := objectstore.OpenBucket(ctx, k8sClient, namespace, nonDefaultBucketConfig, "") + // If the run is cached, it will be in the same bucket but under a different path, re-use the default bucket + // session in this case. 
+ if (nonDefaultBucketConfig.Scheme == defaultBucketConfig.Scheme) && (nonDefaultBucketConfig.BucketName == defaultBucketConfig.BucketName) { + nonDefaultBucketConfig.Session = defaultBucketConfig.Session + } + nonDefaultBucket, err := objectstore.OpenBucket(ctx, k8sClient, namespace, nonDefaultBucketConfig) if err != nil { return nonDefaultBuckets, fmt.Errorf("failed to open bucket for output artifact %q with uri %q: %w", name, artifact.GetUri(), err) } diff --git a/backend/src/v2/component/launcher_v2_test.go b/backend/src/v2/component/launcher_v2_test.go index e3e8835ff8b..5f9438b9fb0 100644 --- a/backend/src/v2/component/launcher_v2_test.go +++ b/backend/src/v2/component/launcher_v2_test.go @@ -77,7 +77,7 @@ func Test_executeV2_Parameters(t *testing.T) { fakeMetadataClient := metadata.NewFakeClient() bucket, err := blob.OpenBucket(context.Background(), "gs://test-bucket") assert.Nil(t, err) - bucketConfig, err := objectstore.ParseBucketConfig("gs://test-bucket/pipeline-root/") + bucketConfig, err := objectstore.ParseBucketConfig("gs://test-bucket/pipeline-root/", nil) assert.Nil(t, err) _, _, err = executeV2(context.Background(), test.executorInput, addNumbersComponent, "sh", test.executorArgs, bucket, bucketConfig, fakeMetadataClient, "namespace", fakeKubernetesClientset) diff --git a/backend/src/v2/config/env.go b/backend/src/v2/config/env.go index 6ff13368b07..b5f0fdc5469 100644 --- a/backend/src/v2/config/env.go +++ b/backend/src/v2/config/env.go @@ -125,7 +125,7 @@ type SecretRef struct { func (c *Config) GetBucketSessionInfo() (objectstore.SessionInfo, error) { path := c.DefaultPipelineRoot() - bucketConfig, err := objectstore.ParseBucketConfig(path) + bucketConfig, err := objectstore.ParseBucketPathToConfig(path) if err != nil { return objectstore.SessionInfo{}, err } diff --git a/backend/src/v2/objectstore/object_store.go b/backend/src/v2/objectstore/object_store.go index c2fe39b1efc..1e3369c67d0 100644 --- a/backend/src/v2/objectstore/object_store.go 
+++ b/backend/src/v2/objectstore/object_store.go @@ -43,15 +43,16 @@ type Config struct { BucketName string Prefix string QueryString string + Session *SessionInfo } -func OpenBucket(ctx context.Context, k8sClient kubernetes.Interface, namespace string, config *Config, bucketSessionInfo string) (bucket *blob.Bucket, err error) { +func OpenBucket(ctx context.Context, k8sClient kubernetes.Interface, namespace string, config *Config) (bucket *blob.Bucket, err error) { defer func() { if err != nil { err = fmt.Errorf("Failed to open bucket %q: %w", config.BucketName, err) } }() - sess, err := createBucketSession(ctx, namespace, bucketSessionInfo, k8sClient) + sess, err := createBucketSession(ctx, namespace, config.Session, k8sClient) if err != nil { return nil, fmt.Errorf("Failed to retrieve credentials for bucket %s: %w", config.BucketName, err) } @@ -171,7 +172,17 @@ func DownloadBlob(ctx context.Context, bucket *blob.Bucket, localDir, blobDir st var bucketPattern = regexp.MustCompile(`(^[a-z][a-z0-9]+:///?)([^/?]+)(/[^?]*)?(\?.+)?$`) -func ParseBucketConfig(path string) (*Config, error) { +func ParseBucketConfig(path string, sess *SessionInfo) (*Config, error) { + config, err := ParseBucketPathToConfig(path) + if err != nil { + return nil, err + } + config.Session = sess + + return config, nil +} + +func ParseBucketPathToConfig(path string) (*Config, error) { ms := bucketPattern.FindStringSubmatch(path) if ms == nil || len(ms) != 5 { return nil, fmt.Errorf("parse bucket config failed: unrecognized pipeline root format: %q", path) @@ -340,15 +351,10 @@ type SessionInfo struct { SecretKeyKey string } -func createBucketSession(ctx context.Context, namespace string, sessionInfoJSON string, client kubernetes.Interface) (*session.Session, error) { - if sessionInfoJSON == "" { +func createBucketSession(ctx context.Context, namespace string, sessionInfo *SessionInfo, client kubernetes.Interface) (*session.Session, error) { + if sessionInfo == nil { return nil, nil } - 
sessionInfo := &SessionInfo{} - err := json.Unmarshal([]byte(sessionInfoJSON), sessionInfo) - if err != nil { - return nil, fmt.Errorf("Encountered error when attempting to unmarshall bucket session properties: %w", err) - } creds, err := getBucketCredential(ctx, client, namespace, sessionInfo.SecretName, sessionInfo.SecretKeyKey, sessionInfo.AccessKeyKey) if err != nil { return nil, err @@ -396,3 +402,15 @@ func getBucketCredential( } return nil, fmt.Errorf("could not find specified keys '%s' or '%s'", bucketAccessKeyKey, bucketSecretKeyKey) } + +func GetSessionInfoFromString(sessionInfoJSON string) (*SessionInfo, error) { + sessionInfo := &SessionInfo{} + if sessionInfoJSON == "" { + return nil, nil + } + err := json.Unmarshal([]byte(sessionInfoJSON), sessionInfo) + if err != nil { + return nil, fmt.Errorf("Encountered error when attempting to unmarshall bucket session properties: %w", err) + } + return sessionInfo, nil +} diff --git a/backend/src/v2/objectstore/object_store_test.go b/backend/src/v2/objectstore/object_store_test.go index 86cd48da521..b88831e755d 100644 --- a/backend/src/v2/objectstore/object_store_test.go +++ b/backend/src/v2/objectstore/object_store_test.go @@ -124,7 +124,7 @@ func Test_parseCloudBucket(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := objectstore.ParseBucketConfig(tt.path) + got, err := objectstore.ParseBucketConfig(tt.path, nil) if (err != nil) != tt.wantErr { t.Errorf("%q: parseCloudBucket() error = %v, wantErr %v", tt.name, err, tt.wantErr) return