Add support for staging to S3 for snowflake uploads #177

Merged: 2 commits, Jun 30, 2023
3 changes: 3 additions & 0 deletions .github/workflows/flow.yml
@@ -60,5 +60,8 @@ jobs:
gotestsum --format testname -- -p 1 ./...
working-directory: ./flow
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION: ${{ secrets.AWS_REGION }}
TEST_BQ_CREDS: ${{ github.workspace }}/bq_service_account.json
TEST_SF_CREDS: ${{ github.workspace }}/snowflake_creds.json
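These three secrets wire the CI tests up to S3 staging: the new Snowflake connector code below obtains credentials via utils.GetAWSSecrets and reads AccessKeyID, SecretAccessKey, and Region from the result. The helper itself is not part of this diff; the following is a minimal sketch of what it plausibly looks like, assuming it simply reads the same environment variables the workflow sets (the struct and field names are inferred from how the diff uses them, not taken from the real implementation).

package utils

import (
	"fmt"
	"os"
)

// AWSSecrets holds the credentials the Snowflake connector needs for
// creating external stages and uploading Avro files to S3.
// This is a sketch; the real definition lives in connectors/utils.
type AWSSecrets struct {
	AccessKeyID     string
	SecretAccessKey string
	Region          string
}

// GetAWSSecrets reads AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and
// AWS_REGION from the environment, failing fast if any is missing.
func GetAWSSecrets() (*AWSSecrets, error) {
	secrets := &AWSSecrets{
		AccessKeyID:     os.Getenv("AWS_ACCESS_KEY_ID"),
		SecretAccessKey: os.Getenv("AWS_SECRET_ACCESS_KEY"),
		Region:          os.Getenv("AWS_REGION"),
	}
	if secrets.AccessKeyID == "" || secrets.SecretAccessKey == "" || secrets.Region == "" {
		return nil, fmt.Errorf("AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_REGION must all be set")
	}
	return secrets, nil
}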
2 changes: 2 additions & 0 deletions .gitignore
@@ -1 +1,3 @@
.vscode
.env

134 changes: 134 additions & 0 deletions flow/connectors/snowflake/avro_writer.go
@@ -0,0 +1,134 @@
package connsnowflake

import (
"fmt"
"io"
"os"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/linkedin/goavro/v2"
log "github.com/sirupsen/logrus"
)

func createOCFWriter(w io.Writer, avroSchema *model.QRecordAvroSchemaDefinition) (*goavro.OCFWriter, error) {
ocfWriter, err := goavro.NewOCFWriter(goavro.OCFConfig{
W: w,
Schema: avroSchema.Schema,
})
if err != nil {
return nil, fmt.Errorf("failed to create OCF writer: %w", err)
}

return ocfWriter, nil
}

func writeRecordsToOCFWriter(
ocfWriter *goavro.OCFWriter,
records *model.QRecordBatch,
avroSchema *model.QRecordAvroSchemaDefinition) error {
colNames := records.Schema.GetColumnNames()

for _, qRecord := range records.Records {
avroConverter := model.NewQRecordAvroConverter(
qRecord,
qvalue.QDWHTypeSnowflake,
&avroSchema.NullableFields,
colNames,
)
avroMap, err := avroConverter.Convert()
if err != nil {
log.Errorf("failed to convert QRecord to Avro compatible map: %v", err)
return fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err)
}

err = ocfWriter.Append([]interface{}{avroMap})
if err != nil {
log.Errorf("failed to write record to OCF: %v", err)
return fmt.Errorf("failed to write record to OCF: %w", err)
}
}

return nil
}

func WriteRecordsToS3(
records *model.QRecordBatch,
avroSchema *model.QRecordAvroSchemaDefinition,
bucketName, key string) error {
r, w := io.Pipe()

go func() {
defer w.Close()

ocfWriter, err := createOCFWriter(w, avroSchema)
if err != nil {
log.Fatalf("failed to create OCF writer: %v", err)
}

if err := writeRecordsToOCFWriter(ocfWriter, records, avroSchema); err != nil {
log.Fatalf("failed to write records to OCF writer: %v", err)
}
}()

awsSecrets, err := utils.GetAWSSecrets()
if err != nil {
log.Errorf("failed to get AWS secrets: %v", err)
return fmt.Errorf("failed to get AWS secrets: %w", err)
}

// Initialize a session; the SDK resolves credentials from the
// environment or the shared credentials file (~/.aws/credentials).
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(awsSecrets.Region),
}))

// Create an uploader with the session and default options
uploader := s3manager.NewUploader(sess)

// Upload the file to S3.
result, err := uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(bucketName),
Key: aws.String(key),
Body: r,
})

if err != nil {
log.Errorf("failed to upload file: %v", err)
return fmt.Errorf("failed to upload file: %w", err)
}

log.Infof("file uploaded to, %s", result.Location)

return nil
}

func WriteRecordsToAvroFile(
records *model.QRecordBatch,
avroSchema *model.QRecordAvroSchemaDefinition,
filePath string) error {
file, err := os.Create(filePath)
if err != nil {
log.Errorf("failed to create file: %v", err)
return fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()

ocfWriter, err := createOCFWriter(file, avroSchema)
if err != nil {
log.Errorf("failed to create OCF writer: %v", err)
return err
}

if err := writeRecordsToOCFWriter(ocfWriter, records, avroSchema); err != nil {
log.Errorf("failed to write records to OCF writer: %v", err)
return err
}

return nil
}
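WriteRecordsToS3 is the heart of this file: it couples the Avro encoder to the S3 uploader through an io.Pipe, so a record batch streams straight into the bucket without ever being written to local disk. One caveat in the version above is that the producer goroutine calls log.Fatalf on errors, which aborts the whole process; io.PipeWriter.CloseWithError is the usual alternative, since it surfaces the failure to the uploader as a failed read. Below is a minimal, self-contained sketch of that pattern; the region, bucket, key, and payload are placeholders, not values from this PR.

package main

import (
	"fmt"
	"io"
	"log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
)

func main() {
	r, w := io.Pipe()

	// Producer: write into the pipe, then close it so the uploader sees EOF.
	// CloseWithError propagates a failure to the reading side instead of
	// terminating the process.
	go func() {
		if _, err := w.Write([]byte("avro bytes would go here")); err != nil {
			w.CloseWithError(fmt.Errorf("encode failed: %w", err))
			return
		}
		w.Close()
	}()

	sess := session.Must(session.NewSession(&aws.Config{
		Region: aws.String("us-west-2"), // placeholder region
	}))

	// The uploader consumes the pipe's read end in chunks, so it never
	// needs to know the total object size up front.
	uploader := s3manager.NewUploader(sess)
	result, err := uploader.Upload(&s3manager.UploadInput{
		Bucket: aws.String("my-bucket"),         // placeholder bucket
		Key:    aws.String("path/to/file.avro"), // placeholder key
		Body:   r,
	})
	if err != nil {
		log.Fatalf("upload failed: %v", err)
	}
	log.Printf("uploaded to %s", result.Location)
}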
44 changes: 31 additions & 13 deletions flow/connectors/snowflake/qrep.go
@@ -3,10 +3,10 @@ package connsnowflake
import (
"database/sql"
"fmt"
"os"
"strings"
"time"

"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
log "github.com/sirupsen/logrus"
@@ -55,12 +55,7 @@ func (c *SnowflakeConnector) SyncQRepRecords(
case protos.QRepSyncMode_QREP_SYNC_MODE_MULTI_INSERT:
return 0, fmt.Errorf("multi-insert sync mode not supported for snowflake")
case protos.QRepSyncMode_QREP_SYNC_MODE_STORAGE_AVRO:
// create a temp directory for storing avro files
tmpDir, err := os.MkdirTemp("", "peerdb-avro")
if err != nil {
return 0, fmt.Errorf("failed to create temp directory: %w", err)
}
avroSync := &SnowflakeAvroSyncMethod{connector: c, localDir: tmpDir}
avroSync := NewSnowflakeAvroSyncMethod(config, c)
return avroSync.SyncQRepRecords(config, partition, tblSchema, records)
default:
return 0, fmt.Errorf("unsupported sync mode: %s", syncMode)
@@ -153,15 +148,38 @@ func (c *SnowflakeConnector) SetupQRepMetadataTables(config *protos.QRepConfig)
log.Infof("Created table %s", qRepMetadataTableName)

stageName := c.getStageNameForJob(config.FlowJobName)
stageStatement := `
CREATE STAGE IF NOT EXISTS %s
FILE_FORMAT = (TYPE = AVRO);
`
stmt := fmt.Sprintf(stageStatement, stageName)

var createStageStmt string
// If the configured staging path starts with s3://, we need to create an external stage.
if strings.HasPrefix(config.StagingPath, "s3://") {
awsCreds, err := utils.GetAWSSecrets()
if err != nil {
log.Errorf("failed to get AWS secrets: %v", err)
return fmt.Errorf("failed to get AWS secrets: %w", err)
}

credsStr := fmt.Sprintf("CREDENTIALS=(AWS_KEY_ID='%s' AWS_SECRET_KEY='%s')",
awsCreds.AccessKeyID, awsCreds.SecretAccessKey)

stageStatement := `
CREATE OR REPLACE STAGE %s
URL = '%s/%s'
%s
FILE_FORMAT = (TYPE = AVRO);
`
createStageStmt = fmt.Sprintf(stageStatement, stageName, config.StagingPath, config.FlowJobName, credsStr)
} else {
stageStatement := `
CREATE OR REPLACE STAGE %s
FILE_FORMAT = (TYPE = AVRO);
`
createStageStmt = fmt.Sprintf(stageStatement, stageName)
}

// Execute the query
_, err = c.database.Exec(stmt)
_, err = c.database.Exec(createStageStmt)
if err != nil {
log.Errorf("failed to create stage %s: %v", stageName, err)
return fmt.Errorf("failed to create stage %s: %w", stageName, err)
}

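When the staging path is an S3 URL, the stage becomes an external stage pointing at a per-job prefix under the bucket. As a worked example with hypothetical values, a StagingPath of s3://my-peerdb-bucket/staging and a flow job named myflow render the template above into roughly:

CREATE OR REPLACE STAGE <stage name for myflow>
URL = 's3://my-peerdb-bucket/staging/myflow'
CREDENTIALS=(AWS_KEY_ID='<access key id>' AWS_SECRET_KEY='<secret key>')
FILE_FORMAT = (TYPE = AVRO);

Two details are worth noting: the statement moved from CREATE STAGE IF NOT EXISTS to CREATE OR REPLACE STAGE, so re-running setup picks up changed credentials or URLs, and the AWS credentials are interpolated directly into the SQL text sent to Snowflake.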
106 changes: 50 additions & 56 deletions flow/connectors/snowflake/qrep_avro_sync.go
@@ -10,22 +10,22 @@ import (
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
"github.com/PeerDB-io/peer-flow/model/qvalue"
util "github.com/PeerDB-io/peer-flow/utils"
"github.com/linkedin/goavro/v2"
log "github.com/sirupsen/logrus"
_ "github.com/snowflakedb/gosnowflake"
)

type SnowflakeAvroSyncMethod struct {
config *protos.QRepConfig
connector *SnowflakeConnector
localDir string
}

func NewSnowflakeAvroSyncMethod(connector *SnowflakeConnector, localDir string) *SnowflakeAvroSyncMethod {
func NewSnowflakeAvroSyncMethod(
config *protos.QRepConfig,
connector *SnowflakeConnector) *SnowflakeAvroSyncMethod {
return &SnowflakeAvroSyncMethod{
config: config,
connector: connector,
localDir: localDir,
}
}
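Note the constructor change: SnowflakeAvroSyncMethod now carries the QRepConfig instead of a pre-created localDir, so writeToAvroFile (below) can decide between a temp directory and S3 at sync time based on config.StagingPath.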

@@ -80,16 +80,55 @@ func (s *SnowflakeAvroSyncMethod) writeToAvroFile(
avroSchema *model.QRecordAvroSchemaDefinition,
partitionID string,
) (string, error) {
localFilePath := fmt.Sprintf("%s/%s.avro", s.localDir, partitionID)
err := WriteRecordsToAvroFile(records, avroSchema, localFilePath)
if err != nil {
return "", fmt.Errorf("failed to write records to Avro file: %w", err)
if s.config.StagingPath == "" {
tmpDir, err := os.MkdirTemp("", "peerdb-avro")
if err != nil {
return "", fmt.Errorf("failed to create temp dir: %w", err)
}

localFilePath := fmt.Sprintf("%s/%s.avro", tmpDir, partitionID)
err = WriteRecordsToAvroFile(records, avroSchema, localFilePath)
if err != nil {
return "", fmt.Errorf("failed to write records to Avro file: %w", err)
}

return localFilePath, nil
} else if strings.HasPrefix(s.config.StagingPath, "s3://") {
// users will have set AWS_REGION, AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY
// in their environment.

// Remove s3:// prefix
stagingPath := strings.TrimPrefix(s.config.StagingPath, "s3://")

// Split into bucket and prefix
splitPath := strings.SplitN(stagingPath, "/", 2)

bucket := splitPath[0]
prefix := ""
if len(splitPath) > 1 {
// Remove leading and trailing slashes from prefix
prefix = strings.Trim(splitPath[1], "/")
}

s3Key := fmt.Sprintf("%s/%s/%s.avro", prefix, s.config.FlowJobName, partitionID)

err := WriteRecordsToS3(records, avroSchema, bucket, s3Key)
if err != nil {
return "", fmt.Errorf("failed to write records to S3: %w", err)
}

return "", nil
}

return localFilePath, nil
return "", fmt.Errorf("unsupported staging path: %s", s.config.StagingPath)
}
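In the S3 branch, writeToAvroFile returns an empty path, which putFileToStage below treats as "nothing to PUT": the file already sits in the external stage's bucket, so no upload to an internal stage is needed. Continuing the earlier hypothetical values, partition p1 of flow job myflow with StagingPath s3://my-peerdb-bucket/staging lands at s3://my-peerdb-bucket/staging/myflow/p1.avro, which is the external stage's URL plus the partition file name.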

func (s *SnowflakeAvroSyncMethod) putFileToStage(localFilePath string, stage string) error {
if localFilePath == "" {
log.Infof("no file to put to stage")
return nil
}

putCmd := fmt.Sprintf("PUT file://%s @%s", localFilePath, stage)
if _, err := s.connector.database.Exec(putCmd); err != nil {
return fmt.Errorf("failed to put file to stage: %w", err)
@@ -157,52 +196,6 @@ func (s *SnowflakeAvroSyncMethod) insertMetadata(
return nil
}

func WriteRecordsToAvroFile(
records *model.QRecordBatch,
avroSchema *model.QRecordAvroSchemaDefinition,
filePath string,
) error {
file, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()

// Create OCF Writer
ocfWriter, err := goavro.NewOCFWriter(goavro.OCFConfig{
W: file,
Schema: avroSchema.Schema,
})
if err != nil {
return fmt.Errorf("failed to create OCF writer: %w", err)
}

colNames := records.Schema.GetColumnNames()

// Write each QRecord to the OCF file
for _, qRecord := range records.Records {
avroConverter := model.NewQRecordAvroConverter(
qRecord,
qvalue.QDWHTypeSnowflake,
&avroSchema.NullableFields,
colNames,
)
avroMap, err := avroConverter.Convert()
if err != nil {
log.Errorf("failed to convert QRecord to Avro compatible map: %v", err)
return fmt.Errorf("failed to convert QRecord to Avro compatible map: %w", err)
}

err = ocfWriter.Append([]interface{}{avroMap})
if err != nil {
log.Errorf("failed to write record to OCF file: %v", err)
return fmt.Errorf("failed to write record to OCF file: %w", err)
}
}

return nil
}

type SnowflakeAvroWriteHandler struct {
db *sql.DB
dstTableName string
@@ -228,6 +221,7 @@ func NewSnowflakeAvroWriteHandler(
func (s *SnowflakeAvroWriteHandler) HandleAppendMode() error {
//nolint:gosec
copyCmd := fmt.Sprintf("COPY INTO %s FROM @%s %s", s.dstTableName, s.stage, strings.Join(s.copyOpts, ","))
log.Infof("running copy command: %s", copyCmd)
if _, err := s.db.Exec(copyCmd); err != nil {
return fmt.Errorf("failed to run COPY INTO command: %w", err)
}
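For an external S3 stage, this COPY is the point where Snowflake actually reads the staged Avro files out of the bucket; there is no preceding PUT. With hypothetical values (destination table public.mytable, stage peerdb_stage_myflow, and no copy options), the logged command would read: COPY INTO public.mytable FROM @peerdb_stage_myflow.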