forked from sql-machine-learning/gomaxcompute
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
869 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
package gomaxcompute | ||
|
||
import ( | ||
"bytes" | ||
"crypto/hmac" | ||
"crypto/sha1" | ||
"encoding/base64" | ||
"errors" | ||
"fmt" | ||
"io/ioutil" | ||
"net/http" | ||
"net/url" | ||
"strconv" | ||
"strings" | ||
"time" | ||
) | ||
|
||
var ( | ||
errNilBody = errors.New("nil body") | ||
requestGMT, _ = time.LoadLocation("GMT") | ||
) | ||
|
||
const currentProject = "curr_project" | ||
|
||
type pair struct { | ||
k string | ||
v string | ||
} | ||
|
||
// optional: Body, Header | ||
func (conn *odpsConn) request(method, resource string, body []byte, header ...pair) (res *http.Response, err error) { | ||
return conn.requestEndpoint(conn.Endpoint, method, resource, body, header...) | ||
} | ||
|
||
func (conn *odpsConn) requestEndpoint(endpoint, method, resource string, body []byte, header ...pair) (res *http.Response, err error) { | ||
var req *http.Request | ||
url := endpoint + resource | ||
if body != nil { | ||
if req, err = http.NewRequest(method, url, bytes.NewBuffer(body)); err != nil { | ||
return | ||
} | ||
req.Header.Set("Content-Length", strconv.Itoa(len(body))) | ||
} else { | ||
if req, err = http.NewRequest(method, url, nil); err != nil { | ||
return | ||
} | ||
} | ||
|
||
req.Header.Set("x-odps-user-agent", "gomaxcompute/0.0.1") | ||
req.Header.Set("Content-Type", "application/xml") | ||
|
||
if dateStr := req.Header.Get("Date"); dateStr == "" { | ||
gmtTime := time.Now().In(requestGMT).Format(time.RFC1123) | ||
req.Header.Set("Date", gmtTime) | ||
} | ||
// overwrite with user-provide header | ||
if header != nil || len(header) == 0 { | ||
for _, arg := range header { | ||
req.Header.Set(arg.k, arg.v) | ||
} | ||
} | ||
|
||
// fill curr_project | ||
if req.URL.Query().Get(currentProject) == "" { | ||
req.URL.Query().Set(currentProject, conn.Project) | ||
} | ||
conn.sign(req) | ||
return conn.Do(req) | ||
} | ||
|
||
// signature | ||
func (conn *odpsConn) sign(r *http.Request) { | ||
var msg, auth bytes.Buffer | ||
msg.WriteString(r.Method) | ||
msg.WriteByte('\n') | ||
// common header | ||
msg.WriteString(r.Header.Get("Content-MD5")) | ||
msg.WriteByte('\n') | ||
msg.WriteString(r.Header.Get("Content-Type")) | ||
msg.WriteByte('\n') | ||
msg.WriteString(r.Header.Get("Date")) | ||
msg.WriteByte('\n') | ||
// canonical header | ||
for k, v := range r.Header { | ||
lowerK := strings.ToLower(k) | ||
if strings.HasPrefix(lowerK, "x-odps-") { | ||
msg.WriteString(lowerK) | ||
msg.WriteByte(':') | ||
msg.WriteString(strings.Join(v, ",")) | ||
msg.WriteByte('\n') | ||
} | ||
} | ||
|
||
// canonical resource | ||
var canonicalResource bytes.Buffer | ||
epURL, _ := url.Parse(conn.Endpoint) | ||
if strings.HasPrefix(r.URL.Path, epURL.Path) { | ||
canonicalResource.WriteString(r.URL.Path[len(epURL.Path):]) | ||
} else { | ||
canonicalResource.WriteString(r.URL.Path) | ||
} | ||
if urlParams := r.URL.Query(); len(urlParams) > 0 { | ||
first := true | ||
for k, v := range urlParams { | ||
if first { | ||
canonicalResource.WriteByte('?') | ||
first = false | ||
} else { | ||
canonicalResource.WriteByte('&') | ||
} | ||
canonicalResource.WriteString(k) | ||
if v != nil && len(v) > 0 && v[0] != "" { | ||
canonicalResource.WriteByte('=') | ||
canonicalResource.WriteString(v[0]) | ||
} | ||
} | ||
} | ||
msg.WriteString(canonicalResource.String()) | ||
|
||
hasher := hmac.New(sha1.New, []byte(conn.AccessKey)) | ||
hasher.Write(msg.Bytes()) | ||
auth.WriteString("ODPS ") | ||
auth.WriteString(conn.AccessID) | ||
auth.WriteByte(':') | ||
auth.WriteString(base64.StdEncoding.EncodeToString(hasher.Sum(nil))) | ||
r.Header.Set("Authorization", auth.String()) | ||
} | ||
|
||
func (cred *Config) resource(resource string, args ...pair) string { | ||
if args == nil || len(args) == 0 { | ||
return fmt.Sprintf("/projects/%s%s", cred.Project, resource) | ||
} | ||
|
||
ps := url.Values{} | ||
for _, i := range args { | ||
ps.Add(i.k, i.v) | ||
} | ||
return fmt.Sprintf("/projects/%s%s?%s", cred.Project, resource, ps.Encode()) | ||
} | ||
|
||
func parseResponseBody(res *http.Response) ([]byte, error) { | ||
if res == nil { | ||
return nil, errors.New("nil response") | ||
} | ||
if res.StatusCode >= 400 { | ||
return nil, fmt.Errorf("bad stataus:%d", res.StatusCode) | ||
} | ||
if res.Body == nil { | ||
return nil, errNilBody | ||
} | ||
|
||
defer res.Body.Close() | ||
return ioutil.ReadAll(res.Body) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
package gomaxcompute | ||
|
||
import ( | ||
"database/sql/driver" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"net/url" | ||
"strings" | ||
"time" | ||
) | ||
|
||
const ( | ||
waitInteveralMs = 1000 | ||
tunnelHTTPProtocal = "http" | ||
terminated = "Terminated" | ||
methodGet = "GET" | ||
methodPost = "POST" | ||
) | ||
|
||
type odpsConn struct { | ||
*http.Client | ||
*Config | ||
} | ||
|
||
// ODPS does not support transaction | ||
func (*odpsConn) Begin() (driver.Tx, error) { | ||
return nil, nil | ||
} | ||
|
||
// ODPS does not support Prepare | ||
func (*odpsConn) Prepare(query string) (driver.Stmt, error) { | ||
panic("Not implemented") | ||
} | ||
|
||
// Goodps accesses server by restful, so Close() do nth. | ||
func (*odpsConn) Close() error { | ||
return nil | ||
} | ||
|
||
// Implements database/sql/driver.Execer. Notice result is nil | ||
func (conn *odpsConn) Exec(query string, args []driver.Value) (driver.Result, error) { | ||
ins, err := conn.wait(query, args) | ||
if err != nil { | ||
return nil, err | ||
} | ||
_, err = conn.getInstanceResult(ins) | ||
return nil, err | ||
} | ||
|
||
// Implements database/sql/driver.Queryer | ||
func (conn *odpsConn) Query(query string, args []driver.Value) (driver.Rows, error) { | ||
ins, err := conn.wait(query, args) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// get tunnel server | ||
tunnelServer, err := conn.getTunnelServer() | ||
if err != nil { | ||
return nil, err | ||
} | ||
// get meta by tunnel | ||
meta, err := conn.getResultMeta(ins, tunnelServer) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
res, err := conn.getInstanceResult(ins) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if strings.HasPrefix(res, "ODPS-") { | ||
return nil, errors.New(res) | ||
} | ||
return newRows(meta, res) | ||
} | ||
|
||
func (conn *odpsConn) getResultMeta(instance, tunnelServer string) (*resultMeta, error) { | ||
endpoint := fmt.Sprintf("%s://%s", tunnelHTTPProtocal, tunnelServer) | ||
rsc := fmt.Sprintf("/projects/%s/instances/%s", conn.Project, instance) | ||
params := url.Values{} | ||
params.Add(currentProject, conn.Project) | ||
params.Add("downloads", "") | ||
url := rsc + "?" + params.Encode() | ||
|
||
rsp, err := conn.requestEndpoint(endpoint, methodPost, url, nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
body, err := parseResponseBody(rsp) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
meta := resultMeta{} | ||
err = json.Unmarshal(body, &meta) | ||
return &meta, err | ||
} | ||
|
||
func (conn *odpsConn) getTunnelServer() (string, error) { | ||
rsp, err := conn.request(methodGet, conn.resource("/tunnel"), nil) | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
url, err := parseResponseBody(rsp) | ||
if err != nil { | ||
return "", err | ||
} | ||
return string(url), nil | ||
} | ||
|
||
func (conn *odpsConn) wait(query string, args []driver.Value) (string, error) { | ||
if len(args) > 0 { | ||
query = fmt.Sprintf(query, args) | ||
} | ||
|
||
ins, err := conn.createInstance(newSQLJob(query)) | ||
if err != nil { | ||
return "", err | ||
} | ||
if err := conn.poll(ins, waitInteveralMs); err != nil { | ||
return "", err | ||
} | ||
return ins, nil | ||
} | ||
|
||
func (conn *odpsConn) poll(instanceID string, interval int) error { | ||
du := time.Duration(interval) * time.Millisecond | ||
for { | ||
status, err := conn.getInstanceStatus(instanceID) | ||
if err != nil { | ||
return err | ||
} | ||
if status == terminated { | ||
break | ||
} | ||
time.Sleep(du) | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package gomaxcompute | ||
|
||
import ( | ||
"database/sql" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestQuery(t *testing.T) { | ||
a := assert.New(t) | ||
db, err := sql.Open("maxcompute", cfg4test.FormatDSN()) | ||
a.NoError(err) | ||
|
||
const sql = `select * from yiyang_test_table1;` | ||
_, err = db.Query(sql) | ||
a.NoError(err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package gomaxcompute | ||
|
||
import ( | ||
"database/sql" | ||
"database/sql/driver" | ||
"net/http" | ||
) | ||
|
||
// register driver | ||
func init() { | ||
sql.Register("maxcompute", &Driver{}) | ||
} | ||
|
||
// impls database/sql/driver.Driver | ||
type Driver struct{} | ||
|
||
func (d Driver) Open(dsn string) (driver.Conn, error) { | ||
cfg, err := ParseDSN(dsn) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return &odpsConn{&http.Client{}, cfg}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package gomaxcompute | ||
|
||
import ( | ||
"database/sql" | ||
"os" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
var cfg4test = &Config{ | ||
AccessID: os.Getenv("ODPS_ACCESS_ID"), | ||
AccessKey: os.Getenv("ODPS_ACCESS_KEY"), | ||
Project: os.Getenv("ODPS_PROJECT"), | ||
Endpoint: os.Getenv("ODPS_ENDPOINT"), | ||
} | ||
|
||
func TestSQLOpen(t *testing.T) { | ||
a := assert.New(t) | ||
db, err := sql.Open("maxcompute", cfg4test.FormatDSN()) | ||
defer db.Close() | ||
a.NoError(err) | ||
} |
Oops, something went wrong.