Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
weiguoz committed May 16, 2019
1 parent f1d7a0c commit bfd56ae
Show file tree
Hide file tree
Showing 14 changed files with 869 additions and 0 deletions.
154 changes: 154 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package gomaxcompute

import (
"bytes"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)

var (
errNilBody = errors.New("nil body")
requestGMT, _ = time.LoadLocation("GMT")
)

const currentProject = "curr_project"

type pair struct {
k string
v string
}

// optional: Body, Header
func (conn *odpsConn) request(method, resource string, body []byte, header ...pair) (res *http.Response, err error) {
return conn.requestEndpoint(conn.Endpoint, method, resource, body, header...)
}

func (conn *odpsConn) requestEndpoint(endpoint, method, resource string, body []byte, header ...pair) (res *http.Response, err error) {
var req *http.Request
url := endpoint + resource
if body != nil {
if req, err = http.NewRequest(method, url, bytes.NewBuffer(body)); err != nil {
return
}
req.Header.Set("Content-Length", strconv.Itoa(len(body)))
} else {
if req, err = http.NewRequest(method, url, nil); err != nil {
return
}
}

req.Header.Set("x-odps-user-agent", "gomaxcompute/0.0.1")
req.Header.Set("Content-Type", "application/xml")

if dateStr := req.Header.Get("Date"); dateStr == "" {
gmtTime := time.Now().In(requestGMT).Format(time.RFC1123)
req.Header.Set("Date", gmtTime)
}
// overwrite with user-provide header
if header != nil || len(header) == 0 {
for _, arg := range header {
req.Header.Set(arg.k, arg.v)
}
}

// fill curr_project
if req.URL.Query().Get(currentProject) == "" {
req.URL.Query().Set(currentProject, conn.Project)
}
conn.sign(req)
return conn.Do(req)
}

// signature
func (conn *odpsConn) sign(r *http.Request) {
var msg, auth bytes.Buffer
msg.WriteString(r.Method)
msg.WriteByte('\n')
// common header
msg.WriteString(r.Header.Get("Content-MD5"))
msg.WriteByte('\n')
msg.WriteString(r.Header.Get("Content-Type"))
msg.WriteByte('\n')
msg.WriteString(r.Header.Get("Date"))
msg.WriteByte('\n')
// canonical header
for k, v := range r.Header {
lowerK := strings.ToLower(k)
if strings.HasPrefix(lowerK, "x-odps-") {
msg.WriteString(lowerK)
msg.WriteByte(':')
msg.WriteString(strings.Join(v, ","))
msg.WriteByte('\n')
}
}

// canonical resource
var canonicalResource bytes.Buffer
epURL, _ := url.Parse(conn.Endpoint)
if strings.HasPrefix(r.URL.Path, epURL.Path) {
canonicalResource.WriteString(r.URL.Path[len(epURL.Path):])
} else {
canonicalResource.WriteString(r.URL.Path)
}
if urlParams := r.URL.Query(); len(urlParams) > 0 {
first := true
for k, v := range urlParams {
if first {
canonicalResource.WriteByte('?')
first = false
} else {
canonicalResource.WriteByte('&')
}
canonicalResource.WriteString(k)
if v != nil && len(v) > 0 && v[0] != "" {
canonicalResource.WriteByte('=')
canonicalResource.WriteString(v[0])
}
}
}
msg.WriteString(canonicalResource.String())

hasher := hmac.New(sha1.New, []byte(conn.AccessKey))
hasher.Write(msg.Bytes())
auth.WriteString("ODPS ")
auth.WriteString(conn.AccessID)
auth.WriteByte(':')
auth.WriteString(base64.StdEncoding.EncodeToString(hasher.Sum(nil)))
r.Header.Set("Authorization", auth.String())
}

func (cred *Config) resource(resource string, args ...pair) string {
if args == nil || len(args) == 0 {
return fmt.Sprintf("/projects/%s%s", cred.Project, resource)
}

ps := url.Values{}
for _, i := range args {
ps.Add(i.k, i.v)
}
return fmt.Sprintf("/projects/%s%s?%s", cred.Project, resource, ps.Encode())
}

func parseResponseBody(res *http.Response) ([]byte, error) {
if res == nil {
return nil, errors.New("nil response")
}
if res.StatusCode >= 400 {
return nil, fmt.Errorf("bad stataus:%d", res.StatusCode)
}
if res.Body == nil {
return nil, errNilBody
}

defer res.Body.Close()
return ioutil.ReadAll(res.Body)
}
143 changes: 143 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package gomaxcompute

import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)

const (
waitInteveralMs = 1000
tunnelHTTPProtocal = "http"
terminated = "Terminated"
methodGet = "GET"
methodPost = "POST"
)

type odpsConn struct {
*http.Client
*Config
}

// ODPS does not support transaction
func (*odpsConn) Begin() (driver.Tx, error) {
return nil, nil
}

// ODPS does not support Prepare
func (*odpsConn) Prepare(query string) (driver.Stmt, error) {
panic("Not implemented")
}

// Goodps accesses server by restful, so Close() do nth.
func (*odpsConn) Close() error {
return nil
}

// Implements database/sql/driver.Execer. Notice result is nil
func (conn *odpsConn) Exec(query string, args []driver.Value) (driver.Result, error) {
ins, err := conn.wait(query, args)
if err != nil {
return nil, err
}
_, err = conn.getInstanceResult(ins)
return nil, err
}

// Implements database/sql/driver.Queryer
func (conn *odpsConn) Query(query string, args []driver.Value) (driver.Rows, error) {
ins, err := conn.wait(query, args)
if err != nil {
return nil, err
}

// get tunnel server
tunnelServer, err := conn.getTunnelServer()
if err != nil {
return nil, err
}
// get meta by tunnel
meta, err := conn.getResultMeta(ins, tunnelServer)
if err != nil {
return nil, err
}

res, err := conn.getInstanceResult(ins)
if err != nil {
return nil, err
}
if strings.HasPrefix(res, "ODPS-") {
return nil, errors.New(res)
}
return newRows(meta, res)
}

func (conn *odpsConn) getResultMeta(instance, tunnelServer string) (*resultMeta, error) {
endpoint := fmt.Sprintf("%s://%s", tunnelHTTPProtocal, tunnelServer)
rsc := fmt.Sprintf("/projects/%s/instances/%s", conn.Project, instance)
params := url.Values{}
params.Add(currentProject, conn.Project)
params.Add("downloads", "")
url := rsc + "?" + params.Encode()

rsp, err := conn.requestEndpoint(endpoint, methodPost, url, nil)
if err != nil {
return nil, err
}
body, err := parseResponseBody(rsp)
if err != nil {
return nil, err
}

meta := resultMeta{}
err = json.Unmarshal(body, &meta)
return &meta, err
}

func (conn *odpsConn) getTunnelServer() (string, error) {
rsp, err := conn.request(methodGet, conn.resource("/tunnel"), nil)
if err != nil {
return "", err
}

url, err := parseResponseBody(rsp)
if err != nil {
return "", err
}
return string(url), nil
}

func (conn *odpsConn) wait(query string, args []driver.Value) (string, error) {
if len(args) > 0 {
query = fmt.Sprintf(query, args)
}

ins, err := conn.createInstance(newSQLJob(query))
if err != nil {
return "", err
}
if err := conn.poll(ins, waitInteveralMs); err != nil {
return "", err
}
return ins, nil
}

func (conn *odpsConn) poll(instanceID string, interval int) error {
du := time.Duration(interval) * time.Millisecond
for {
status, err := conn.getInstanceStatus(instanceID)
if err != nil {
return err
}
if status == terminated {
break
}
time.Sleep(du)
}
return nil
}
18 changes: 18 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package gomaxcompute

import (
"database/sql"
"testing"

"github.com/stretchr/testify/assert"
)

func TestQuery(t *testing.T) {
a := assert.New(t)
db, err := sql.Open("maxcompute", cfg4test.FormatDSN())
a.NoError(err)

const sql = `select * from yiyang_test_table1;`
_, err = db.Query(sql)
a.NoError(err)
}
23 changes: 23 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package gomaxcompute

import (
"database/sql"
"database/sql/driver"
"net/http"
)

// register driver
func init() {
sql.Register("maxcompute", &Driver{})
}

// impls database/sql/driver.Driver
type Driver struct{}

func (d Driver) Open(dsn string) (driver.Conn, error) {
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
return &odpsConn{&http.Client{}, cfg}, nil
}
23 changes: 23 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package gomaxcompute

import (
"database/sql"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

var cfg4test = &Config{
AccessID: os.Getenv("ODPS_ACCESS_ID"),
AccessKey: os.Getenv("ODPS_ACCESS_KEY"),
Project: os.Getenv("ODPS_PROJECT"),
Endpoint: os.Getenv("ODPS_ENDPOINT"),
}

func TestSQLOpen(t *testing.T) {
a := assert.New(t)
db, err := sql.Open("maxcompute", cfg4test.FormatDSN())
defer db.Close()
a.NoError(err)
}
Loading

0 comments on commit bfd56ae

Please sign in to comment.