
Commit 53aeb09

Add support for staging to S3 for snowflake uploads (#177)
1 parent bab2fa0

13 files changed: +374 additions, -104 deletions
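
The new behaviour is keyed off the StagingPath field on protos.QRepConfig, which the Snowflake connector inspects in the qrep.go and qrep_avro_sync.go hunks below: an s3:// path switches staging to an external stage backed by S3, while an empty path keeps the existing local-file flow. A minimal sketch of how a caller might opt in (the job name, bucket, and prefix are hypothetical, and the rest of the config is elided):

package main

import "github.com/PeerDB-io/peer-flow/generated/protos"

func main() {
	// Hypothetical values; only FlowJobName and StagingPath matter for this change.
	cfg := &protos.QRepConfig{
		FlowJobName: "my_qrep_job",
		// An s3:// staging path makes the connector create an external Snowflake stage
		// and upload Avro files to S3; leaving it empty keeps the local temp-dir flow.
		StagingPath: "s3://my-bucket/peerdb",
	}
	_ = cfg
}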

.github/workflows/flow.yml

Lines changed: 3 additions & 0 deletions
@@ -60,5 +60,8 @@ jobs:
           gotestsum --format testname -- -p 1 ./...
         working-directory: ./flow
         env:
+          AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          AWS_REGION: ${{ secrets.AWS_REGION }}
           TEST_BQ_CREDS: ${{ github.workspace }}/bq_service_account.json
           TEST_SF_CREDS: ${{ github.workspace }}/snowflake_creds.json
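
The test job now receives AWS credentials through the standard environment variables so the S3 staging path can be exercised in CI. The utils.GetAWSSecrets helper called later in this diff presumably reads the same variables; a rough sketch under that assumption (only the Region, AccessKeyID, and SecretAccessKey fields are visible in the diff, so the struct shape and error handling here are guesses, not the actual connectors/utils code):

package utils

import (
	"fmt"
	"os"
)

// AWSSecrets mirrors the fields used elsewhere in this diff; the exact shape is assumed.
type AWSSecrets struct {
	Region          string
	AccessKeyID     string
	SecretAccessKey string
}

// GetAWSSecrets reads the credentials injected above as AWS_REGION,
// AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.
func GetAWSSecrets() (*AWSSecrets, error) {
	region := os.Getenv("AWS_REGION")
	if region == "" {
		return nil, fmt.Errorf("AWS_REGION must be set")
	}
	return &AWSSecrets{
		Region:          region,
		AccessKeyID:     os.Getenv("AWS_ACCESS_KEY_ID"),
		SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"),
	}, nil
}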

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 .vscode
+.env
+
Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
package connsnowflake

import (
	"fmt"
	"io"
	"os"

	"github.com/PeerDB-io/peer-flow/connectors/utils"
	"github.com/PeerDB-io/peer-flow/model"
	"github.com/PeerDB-io/peer-flow/model/qvalue"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/linkedin/goavro/v2"
	log "github.com/sirupsen/logrus"
)

func createOCFWriter(w io.Writer, avroSchema *model.QRecordAvroSchemaDefinition) (*goavro.OCFWriter, error) {
	ocfWriter, err := goavro.NewOCFWriter(goavro.OCFConfig{
		W:      w,
		Schema: avroSchema.Schema,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create OCF writer: %w", err)
	}

	return ocfWriter, nil
}

func writeRecordsToOCFWriter(
	ocfWriter *goavro.OCFWriter,
	records *model.QRecordBatch,
	avroSchema *model.QRecordAvroSchemaDefinition) error {
	colNames := records.Schema.GetColumnNames()

	for _, qRecord := range records.Records {
		avroConverter := model.NewQRecordAvroConverter(
			qRecord,
			qvalue.QDWHTypeSnowflake,
			&avroSchema.NullableFields,
			colNames,
		)
		avroMap, err := avroConverter.Convert()
		if err != nil {
			log.Errorf("failed to convert QRecord to Avro compatible map: %v", err)
			return fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err)
		}

		err = ocfWriter.Append([]interface{}{avroMap})
		if err != nil {
			log.Errorf("failed to write record to OCF: %v", err)
			return fmt.Errorf("failed to write record to OCF: %w", err)
		}
	}

	return nil
}

func WriteRecordsToS3(
	records *model.QRecordBatch,
	avroSchema *model.QRecordAvroSchemaDefinition,
	bucketName, key string) error {
	r, w := io.Pipe()

	go func() {
		defer w.Close()

		ocfWriter, err := createOCFWriter(w, avroSchema)
		if err != nil {
			log.Fatalf("failed to create OCF writer: %v", err)
		}

		if err := writeRecordsToOCFWriter(ocfWriter, records, avroSchema); err != nil {
			log.Fatalf("failed to write records to OCF writer: %v", err)
		}
	}()

	awsSecrets, err := utils.GetAWSSecrets()
	if err != nil {
		log.Errorf("failed to get AWS secrets: %v", err)
		return fmt.Errorf("failed to get AWS secrets: %w", err)
	}

	// Initialize a session that the SDK will use to load
	// credentials from the shared credentials file (~/.aws/credentials).
	sess := session.Must(session.NewSession(&aws.Config{
		Region: aws.String(awsSecrets.Region),
	}))

	// Create an uploader with the session and default options
	uploader := s3manager.NewUploader(sess)

	// Upload the file to S3.
	result, err := uploader.Upload(&s3manager.UploadInput{
		Bucket: aws.String(bucketName),
		Key:    aws.String(key),
		Body:   r,
	})

	if err != nil {
		log.Errorf("failed to upload file: %v", err)
		return fmt.Errorf("failed to upload file: %w", err)
	}

	log.Infof("file uploaded to %s", result.Location)

	return nil
}

func WriteRecordsToAvroFile(
	records *model.QRecordBatch,
	avroSchema *model.QRecordAvroSchemaDefinition,
	filePath string) error {
	file, err := os.Create(filePath)
	if err != nil {
		log.Errorf("failed to create file: %v", err)
		return fmt.Errorf("failed to create file: %w", err)
	}
	defer file.Close()

	ocfWriter, err := createOCFWriter(file, avroSchema)
	if err != nil {
		log.Errorf("failed to create OCF writer: %v", err)
		return err
	}

	if err := writeRecordsToOCFWriter(ocfWriter, records, avroSchema); err != nil {
		log.Errorf("failed to write records to OCF writer: %v", err)
		return err
	}

	return nil
}
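
WriteRecordsToS3 encodes the batch as Avro OCF on one end of an io.Pipe while the S3 uploader streams from the other end, so a partition never has to be materialised on local disk before upload. A minimal usage sketch; records and avroSchema would come from the QRep sync path, and the bucket, key, and file path below are hypothetical:

package connsnowflake

import (
	"github.com/PeerDB-io/peer-flow/model"
	log "github.com/sirupsen/logrus"
)

// stageBatchExample is a hypothetical helper showing the two write paths above.
func stageBatchExample(records *model.QRecordBatch, avroSchema *model.QRecordAvroSchemaDefinition) {
	// S3 staging: used when an s3:// staging path is configured.
	if err := WriteRecordsToS3(records, avroSchema, "my-bucket", "peerdb/my_qrep_job/partition-1.avro"); err != nil {
		log.Errorf("failed to write records to S3: %v", err)
	}

	// Local staging: the default when no staging path is set.
	if err := WriteRecordsToAvroFile(records, avroSchema, "/tmp/peerdb-avro/partition-1.avro"); err != nil {
		log.Errorf("failed to write records to Avro file: %v", err)
	}
}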

flow/connectors/snowflake/qrep.go

Lines changed: 31 additions & 13 deletions
@@ -3,10 +3,10 @@ package connsnowflake
 import (
 	"database/sql"
 	"fmt"
-	"os"
 	"strings"
 	"time"

+	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 	"github.com/PeerDB-io/peer-flow/model"
 	log "github.com/sirupsen/logrus"

@@ -55,12 +55,7 @@ func (c *SnowflakeConnector) SyncQRepRecords(
 	case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT:
 		return 0, fmt.Errorf("multi-insert sync mode not supported for snowflake")
 	case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO:
-		// create a temp directory for storing avro files
-		tmpDir, err := os.MkdirTemp("", "peerdb-avro")
-		if err != nil {
-			return 0, fmt.Errorf("failed to create temp directory: %w", err)
-		}
-		avroSync := &SnowflakeAvroSyncMethod{connector: c, localDir: tmpDir}
+		avroSync := NewSnowflakeAvroSyncMethod(config, c)
 		return avroSync.SyncQRepRecords(config, partition, tblSchema, records)
 	default:
 		return 0, fmt.Errorf("unsupported sync mode: %s", syncMode)

@@ -153,15 +148,38 @@ func (c *SnowflakeConnector) SetupQRepMetadataTables(config *protos.QRepConfig)
 	log.Infof("Created table %s", qRepMetadataTableName)

 	stageName := c.getStageNameForJob(config.FlowJobName)
-	stageStatement := `
-		CREATE STAGE IF NOT EXISTS %s
-		FILE_FORMAT = (TYPE = AVRO);
-	`
-	stmt := fmt.Sprintf(stageStatement, stageName)
+
+	var createStageStmt string
+	// if config staging path starts with S3 we need to create an external stage.
+	if strings.HasPrefix(config.StagingPath, "s3://") {
+		awsCreds, err := utils.GetAWSSecrets()
+		if err != nil {
+			log.Errorf("failed to get AWS secrets: %v", err)
+			return fmt.Errorf("failed to get AWS secrets: %w", err)
+		}
+
+		credsStr := fmt.Sprintf("CREDENTIALS=(AWS_KEY_ID='%s' AWS_SECRET_KEY='%s')",
+			awsCreds.AccessKeyID, awsCreds.SecretAccessKey)
+
+		stageStatement := `
+			CREATE OR REPLACE STAGE %s
+			URL = '%s/%s'
+			%s
+			FILE_FORMAT = (TYPE = AVRO);
+		`
+		createStageStmt = fmt.Sprintf(stageStatement, stageName, config.StagingPath, config.FlowJobName, credsStr)
+	} else {
+		stageStatement := `
+			CREATE OR REPLACE STAGE %s
+			FILE_FORMAT = (TYPE = AVRO);
+		`
+		createStageStmt = fmt.Sprintf(stageStatement, stageName)
+	}

 	// Execute the query
-	_, err = c.database.Exec(stmt)
+	_, err = c.database.Exec(createStageStmt)
 	if err != nil {
+		log.Errorf("failed to create stage %s: %v", stageName, err)
 		return fmt.Errorf("failed to create stage %s: %w", stageName, err)
 	}
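
For an s3:// staging path the connector now creates an external stage whose URL is the staging path plus the flow job name, with credentials inlined from the environment; otherwise it keeps an internal stage. A small sketch of how the external-stage DDL ends up rendered for hypothetical values (the stage name and placeholder credentials are illustrative, not what getStageNameForJob or GetAWSSecrets actually produce):

package main

import "fmt"

func main() {
	// Hypothetical inputs, mirroring the Sprintf call in SetupQRepMetadataTables.
	stageName := "peerdb_stage_my_qrep_job"
	stagingPath := "s3://my-bucket/peerdb"
	flowJobName := "my_qrep_job"
	credsStr := "CREDENTIALS=(AWS_KEY_ID='<key>' AWS_SECRET_KEY='<secret>')"

	stageStatement := `
		CREATE OR REPLACE STAGE %s
		URL = '%s/%s'
		%s
		FILE_FORMAT = (TYPE = AVRO);
	`
	// Prints the DDL with URL = 's3://my-bucket/peerdb/my_qrep_job'.
	fmt.Printf(stageStatement, stageName, stagingPath, flowJobName, credsStr)
}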

flow/connectors/snowflake/qrep_avro_sync.go

Lines changed: 50 additions & 56 deletions
@@ -10,22 +10,22 @@ import (
 	"github.com/PeerDB-io/peer-flow/connectors/utils"
 	"github.com/PeerDB-io/peer-flow/generated/protos"
 	"github.com/PeerDB-io/peer-flow/model"
-	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	util "github.com/PeerDB-io/peer-flow/utils"
-	"github.com/linkedin/goavro/v2"
 	log "github.com/sirupsen/logrus"
 	_ "github.com/snowflakedb/gosnowflake"
 )

 type SnowflakeAvroSyncMethod struct {
+	config    *protos.QRepConfig
 	connector *SnowflakeConnector
-	localDir  string
 }

-func NewSnowflakeAvroSyncMethod(connector *SnowflakeConnector, localDir string) *SnowflakeAvroSyncMethod {
+func NewSnowflakeAvroSyncMethod(
+	config *protos.QRepConfig,
+	connector *SnowflakeConnector) *SnowflakeAvroSyncMethod {
 	return &SnowflakeAvroSyncMethod{
+		config:    config,
 		connector: connector,
-		localDir:  localDir,
 	}
 }

@@ -80,16 +80,55 @@ func (s *SnowflakeAvroSyncMethod) writeToAvroFile(
 	avroSchema *model.QRecordAvroSchemaDefinition,
 	partitionID string,
 ) (string, error) {
-	localFilePath := fmt.Sprintf("%s/%s.avro", s.localDir, partitionID)
-	err := WriteRecordsToAvroFile(records, avroSchema, localFilePath)
-	if err != nil {
-		return "", fmt.Errorf("failed to write records to Avro file: %w", err)
+	if s.config.StagingPath == "" {
+		tmpDir, err := os.MkdirTemp("", "peerdb-avro")
+		if err != nil {
+			return "", fmt.Errorf("failed to create temp dir: %w", err)
+		}
+
+		localFilePath := fmt.Sprintf("%s/%s.avro", tmpDir, partitionID)
+		err = WriteRecordsToAvroFile(records, avroSchema, localFilePath)
+		if err != nil {
+			return "", fmt.Errorf("failed to write records to Avro file: %w", err)
+		}
+
+		return localFilePath, nil
+	} else if strings.HasPrefix(s.config.StagingPath, "s3://") {
+		// users will have set AWS_REGION, AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY
+		// in their environment.
+
+		// Remove s3:// prefix
+		stagingPath := strings.TrimPrefix(s.config.StagingPath, "s3://")
+
+		// Split into bucket and prefix
+		splitPath := strings.SplitN(stagingPath, "/", 2)
+
+		bucket := splitPath[0]
+		prefix := ""
+		if len(splitPath) > 1 {
+			// Remove leading and trailing slashes from prefix
+			prefix = strings.Trim(splitPath[1], "/")
+		}
+
+		s3Key := fmt.Sprintf("%s/%s/%s.avro", prefix, s.config.FlowJobName, partitionID)
+
+		err := WriteRecordsToS3(records, avroSchema, bucket, s3Key)
+		if err != nil {
+			return "", fmt.Errorf("failed to write records to S3: %w", err)
+		}
+
+		return "", nil
 	}

-	return localFilePath, nil
+	return "", fmt.Errorf("unsupported staging path: %s", s.config.StagingPath)
 }

 func (s *SnowflakeAvroSyncMethod) putFileToStage(localFilePath string, stage string) error {
+	if localFilePath == "" {
+		log.Infof("no file to put to stage")
+		return nil
+	}
+
 	putCmd := fmt.Sprintf("PUT file://%s @%s", localFilePath, stage)
 	if _, err := s.connector.database.Exec(putCmd); err != nil {
 		return fmt.Errorf("failed to put file to stage: %w", err)

@@ -157,52 +196,6 @@ func (s *SnowflakeAvroSyncMethod) insertMetadata(
 	return nil
 }

-func WriteRecordsToAvroFile(
-	records *model.QRecordBatch,
-	avroSchema *model.QRecordAvroSchemaDefinition,
-	filePath string,
-) error {
-	file, err := os.Create(filePath)
-	if err != nil {
-		return fmt.Errorf("failed to create file: %w", err)
-	}
-	defer file.Close()
-
-	// Create OCF Writer
-	ocfWriter, err := goavro.NewOCFWriter(goavro.OCFConfig{
-		W:      file,
-		Schema: avroSchema.Schema,
-	})
-	if err != nil {
-		return fmt.Errorf("failed to create OCF writer: %w", err)
-	}
-
-	colNames := records.Schema.GetColumnNames()
-
-	// Write each QRecord to the OCF file
-	for _, qRecord := range records.Records {
-		avroConverter := model.NewQRecordAvroConverter(
-			qRecord,
-			qvalue.QDWHTypeSnowflake,
-			&avroSchema.NullableFields,
-			colNames,
-		)
-		avroMap, err := avroConverter.Convert()
-		if err != nil {
-			log.Errorf("failed to convert QRecord to Avro compatible map: %v", err)
-			return fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err)
-		}
-
-		err = ocfWriter.Append([]interface{}{avroMap})
-		if err != nil {
-			log.Errorf("failed to write record to OCF file: %v", err)
-			return fmt.Errorf("failed to write record to OCF file: %w", err)
-		}
-	}
-
-	return nil
-}
-
 type SnowflakeAvroWriteHandler struct {
 	db           *sql.DB
 	dstTableName string

@@ -228,6 +221,7 @@ func NewSnowflakeAvroWriteHandler(
 func (s *SnowflakeAvroWriteHandler) HandleAppendMode() error {
 	//nolint:gosec
 	copyCmd := fmt.Sprintf("COPY INTO %s FROM @%s %s", s.dstTableName, s.stage, strings.Join(s.copyOpts, ","))
+	log.Infof("running copy command: %s", copyCmd)
 	if _, err := s.db.Exec(copyCmd); err != nil {
 		return fmt.Errorf("failed to run COPY INTO command: %w", err)
 	}
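
When an s3:// staging path is set, writeToAvroFile uploads the partition straight to S3 and returns an empty local path, so putFileToStage becomes a no-op and COPY INTO reads from the external stage instead. The object key layout is <prefix>/<flow_job_name>/<partition_id>.avro. A small sketch of the path handling under hypothetical values (splitStagingPath is an illustrative helper, not part of the codebase):

package main

import (
	"fmt"
	"strings"
)

// splitStagingPath mirrors the parsing in writeToAvroFile above: it breaks an
// s3:// staging path into bucket and prefix.
func splitStagingPath(stagingPath string) (bucket, prefix string) {
	trimmed := strings.TrimPrefix(stagingPath, "s3://")
	parts := strings.SplitN(trimmed, "/", 2)
	bucket = parts[0]
	if len(parts) > 1 {
		prefix = strings.Trim(parts[1], "/")
	}
	return bucket, prefix
}

func main() {
	// Hypothetical staging path, flow job name, and partition ID.
	bucket, prefix := splitStagingPath("s3://my-bucket/peerdb/staging/")
	key := fmt.Sprintf("%s/%s/%s.avro", prefix, "my_qrep_job", "partition-1")

	// Prints: bucket=my-bucket key=peerdb/staging/my_qrep_job/partition-1.avro
	fmt.Printf("bucket=%s key=%s\n", bucket, key)
}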
