diff --git a/connection.go b/connection.go index 0c4e1a6..0d4190a 100644 --- a/connection.go +++ b/connection.go @@ -66,10 +66,18 @@ func (c *Connection) Close() error { return nil } -func (c *Connection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { +func removeLastSemicolon(s string) string { + n := len(s) + if n > 0 && s[n-1] == ';' { + return s[0 : n-1] + } + return s +} + +func (c *Connection) execute(ctx context.Context, query string, args []driver.NamedValue) (*hiveserver2.TExecuteStatementResp, error) { executeReq := hiveserver2.NewTExecuteStatementReq() executeReq.SessionHandle = c.session - executeReq.Statement = query + executeReq.Statement = removeLastSemicolon(query) resp, err := c.thrift.ExecuteStatement(executeReq) if err != nil { @@ -79,22 +87,21 @@ func (c *Connection) QueryContext(ctx context.Context, query string, args []driv if !isSuccessStatus(resp.Status) { return nil, fmt.Errorf("Error from server: %s", resp.Status.String()) } + return resp, nil +} +func (c *Connection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + resp, err := c.execute(ctx, query, args) + if err != nil { + return nil, err + } return newRows(c.thrift, resp.OperationHandle, c.options), nil } func (c *Connection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - executeReq := hiveserver2.NewTExecuteStatementReq() - executeReq.SessionHandle = c.session - executeReq.Statement = query - - resp, err := c.thrift.ExecuteStatement(executeReq) + resp, err := c.execute(ctx, query, args) if err != nil { - return nil, fmt.Errorf("Error in ExecuteStatement: %+v, %v", resp, err) - } - - if !isSuccessStatus(resp.Status) { - return nil, fmt.Errorf("Error from server: %s", resp.Status.String()) + return nil, err } return newHiveResult(resp.OperationHandle), nil } diff --git a/driver_test.go b/driver_test.go index 76b78c2..7ccdb9a 100644 --- a/driver_test.go +++ b/driver_test.go @@ -36,7 +36,7 @@ func TestQuery(t *testing.T) { func TestColumnName(t *testing.T) { a := assert.New(t) db, _ := sql.Open("hive", "127.0.0.1:10000/churn") - rows, err := db.Query("SELECT customerID, gender FROM train") + rows, err := db.Query("SELECT customerID, gender FROM train;") assert.Nil(t, err) defer db.Close() defer rows.Close()