Skip to content

Commit 9112ac7

Browse files
committed
Changed staging operation so it doesn't return
Signed-off-by: nithinkdb <nithin.krishnamurthi@databricks.com>
1 parent b50616b commit 9112ac7

File tree

2 files changed

+40
-32
lines changed

2 files changed

+40
-32
lines changed

connection.go

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,22 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
126126
}
127127
resp, err := c.client.GetResultSetMetadata(ctx, &req)
128128
if err != nil {
129-
return nil, dbsqlerrint.NewDriverError(ctx, "Error performing staging operation", err)
129+
return nil, dbsqlerrint.NewDriverError(ctx, "error performing staging operation", err)
130130
}
131131
isStagingOperation = *resp.IsStagingOperation
132132
}
133133
if isStagingOperation {
134134
if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
135135
row, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
136136
if err != nil {
137-
return nil, dbsqlerrint.NewDriverError(ctx, "Error reading row.", err)
137+
return nil, dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
138+
}
139+
err = c.ExecStagingOperation(ctx, row)
140+
if err != nil {
141+
return nil, err
138142
}
139-
return c.ExecStagingOperation(ctx, row)
140143
} else {
141-
return nil, dbsqlerrint.NewDriverError(ctx, "Staging ctx must be provided.", nil)
144+
return nil, dbsqlerrint.NewDriverError(ctx, "staging ctx must be provided.", nil)
142145
}
143146
}
144147

@@ -174,39 +177,40 @@ func Succeeded(response *http.Response) bool {
174177
return false
175178
}
176179

177-
func (c *conn) HandleStagingPut(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) (driver.Result, error) {
180+
func (c *conn) HandleStagingPut(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError {
178181
if localFile == "" {
179-
return nil, dbsqlerrint.NewDriverError(ctx, "cannot perform PUT without specifying a local_file", nil)
182+
return dbsqlerrint.NewDriverError(ctx, "cannot perform PUT without specifying a local_file", nil)
180183
}
181184
client := &http.Client{}
182185

183186
dat, err := os.ReadFile(localFile)
184187

185-
req, _ := http.NewRequest("PUT", presignedUrl, bytes.NewReader(dat))
186-
187188
if err != nil {
188-
return nil, err
189+
return dbsqlerrint.NewDriverError(ctx, "error reading local file", err)
189190
}
191+
192+
req, _ := http.NewRequest("PUT", presignedUrl, bytes.NewReader(dat))
193+
190194
for k, v := range headers {
191195
req.Header.Set(k, v)
192196
}
193197
res, err := client.Do(req)
194198
if err != nil {
195-
return nil, err
199+
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
196200
}
197201
defer res.Body.Close()
198202
content, err := io.ReadAll(res.Body)
199203

200204
if err != nil || !Succeeded(res) {
201-
return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil)
205+
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil)
202206
}
203-
return driver.ResultNoRows, nil
207+
return nil
204208

205209
}
206210

207-
func (c *conn) HandleStagingGet(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) (driver.Result, error) {
211+
func (c *conn) HandleStagingGet(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError {
208212
if localFile == "" {
209-
return nil, fmt.Errorf("cannot perform GET without specifying a local_file")
213+
return dbsqlerrint.NewDriverError(ctx, "cannot perform GET without specifying a local_file", nil)
210214
}
211215
client := &http.Client{}
212216
req, _ := http.NewRequest("GET", presignedUrl, nil)
@@ -216,41 +220,40 @@ func (c *conn) HandleStagingGet(ctx context.Context, presignedUrl string, header
216220
}
217221
res, err := client.Do(req)
218222
if err != nil {
219-
return nil, err
223+
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
220224
}
221225
defer res.Body.Close()
222226
content, err := io.ReadAll(res.Body)
223227

224228
if err != nil || !Succeeded(res) {
225-
return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil)
229+
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil)
226230
}
227231

228232
err = os.WriteFile(localFile, content, 0644) //nolint:gosec
229233
if err != nil {
230-
return nil, err
234+
return dbsqlerrint.NewDriverError(ctx, "error writing local file", err)
231235
}
232-
return driver.ResultNoRows, nil
233-
236+
return nil
234237
}
235238

236-
func (c *conn) HandleStagingDelete(ctx context.Context, presignedUrl string, headers map[string]string) (driver.Result, error) {
239+
func (c *conn) HandleStagingDelete(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError {
237240
client := &http.Client{}
238241
req, _ := http.NewRequest("DELETE", presignedUrl, nil)
239242
for k, v := range headers {
240243
req.Header.Set(k, v)
241244
}
242245
res, err := client.Do(req)
243246
if err != nil {
244-
return nil, err
247+
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
245248
}
246249
defer res.Body.Close()
247250
content, err := io.ReadAll(res.Body)
248251

249252
if err != nil || !Succeeded(res) {
250-
return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil)
253+
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil)
251254
}
252255

253-
return driver.ResultNoRows, nil
256+
return nil
254257
}
255258

256259
func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) bool {
@@ -278,29 +281,29 @@ func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) boo
278281

279282
func (c *conn) ExecStagingOperation(
280283
ctx context.Context,
281-
row driver.Rows) (driver.Result, error) {
284+
row driver.Rows) dbsqlerr.DBError {
282285

283286
var sqlRow []driver.Value
284287
colNames := row.Columns()
285288
sqlRow = make([]driver.Value, len(colNames))
286289
err := row.Next(sqlRow)
287290
if err != nil {
288-
return nil, dbsqlerrint.NewDriverError(ctx, "Error fetching staging operation results", err)
291+
return dbsqlerrint.NewDriverError(ctx, "error fetching staging operation results", err)
289292
}
290293
var stringValues []string = make([]string, 4)
291294
for i := range stringValues {
292295
if s, ok := sqlRow[i].(string); ok {
293296
stringValues[i] = s
294297
} else {
295-
return nil, dbsqlerrint.NewDriverError(ctx, "Received unexpected response from the server.", nil)
298+
return dbsqlerrint.NewDriverError(ctx, "received unexpected response from the server.", nil)
296299
}
297300
}
298301
operation := stringValues[0]
299302
presignedUrl := stringValues[1]
300303
headersByteArr := []byte(stringValues[2])
301304
var headers map[string]string
302305
if err := json.Unmarshal(headersByteArr, &headers); err != nil {
303-
return nil, err
306+
return dbsqlerrint.NewDriverError(ctx, "error parsing server response.", nil)
304307
}
305308
localFile := stringValues[3]
306309
stagingAllowedLocalPaths := driverctx.StagingPathsFromContext(ctx)
@@ -309,18 +312,18 @@ func (c *conn) ExecStagingOperation(
309312
if localPathIsAllowed(stagingAllowedLocalPaths, localFile) {
310313
return c.HandleStagingPut(ctx, presignedUrl, headers, localFile)
311314
} else {
312-
return nil, dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil)
315+
return dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil)
313316
}
314317
case "GET":
315318
if localPathIsAllowed(stagingAllowedLocalPaths, localFile) {
316319
return c.HandleStagingGet(ctx, presignedUrl, headers, localFile)
317320
} else {
318-
return nil, dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil)
321+
return dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil)
319322
}
320323
case "DELETE":
321324
return c.HandleStagingDelete(ctx, presignedUrl, headers)
322325
default:
323-
return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("operation %s is not supported. Supported operations are GET, PUT, and REMOVE", operation), nil)
326+
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("operation %s is not supported. Supported operations are GET, PUT, and REMOVE", operation), nil)
324327
}
325328
}
326329

statement_test.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,15 @@ func TestStmt_ExecContext(t *testing.T) {
8787
}
8888
return getOperationStatusResp, nil
8989
}
90+
fetchResultSetMetadata := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (_r *cli_service.TGetResultSetMetadataResp, _err error) {
91+
var b = false
92+
return &cli_service.TGetResultSetMetadataResp{IsStagingOperation: &b}, nil
93+
}
9094

9195
testClient := &client.TestClient{
92-
FnExecuteStatement: executeStatement,
93-
FnGetOperationStatus: getOperationStatus,
96+
FnExecuteStatement: executeStatement,
97+
FnGetOperationStatus: getOperationStatus,
98+
FnGetResultSetMetadata: fetchResultSetMetadata,
9499
}
95100
testConn := &conn{
96101
session: getTestSession(),

0 commit comments

Comments
 (0)