Skip to content

Commit

Permalink
support auth mechanism (#51)
Browse files Browse the repository at this point in the history
* support auth mechnism

* polish code

* follow comment
  • Loading branch information
Yancey1989 authored Aug 8, 2019
1 parent 1e12c84 commit abe07b6
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 22 deletions.
34 changes: 24 additions & 10 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"github.com/apache/thrift/lib/go/thrift"
"fmt"

"github.com/apache/thrift/lib/go/thrift"
bgohive "github.com/beltran/gohive"
hiveserver2 "github.com/sql-machine-learning/gohive/hiveserver2/gen-go/tcliservice"
)

Expand All @@ -17,23 +18,36 @@ func (d drv) Open(dsn string) (driver.Conn, error) {
if err != nil {
return nil, err
}
transport, err := thrift.NewTSocket(cfg.Addr)
socket, err := thrift.NewTSocket(cfg.Addr)
if err != nil {
return nil, err
}

if err := transport.Open(); err != nil {
return nil, err
var transport thrift.TTransport
if cfg.Auth == "NOSASL" {
transport = thrift.NewTBufferedTransport(socket, 4096)
if transport == nil {
return nil, fmt.Errorf("BufferedTransport is nil")
}
} else if cfg.Auth == "PLAIN" || cfg.Auth == "GSSAPI" || cfg.Auth == "LDAP" {
saslCfg := map[string]string{
"username": cfg.User,
"password": cfg.Passwd,
}
transport, err = bgohive.NewTSaslTransport(socket, cfg.Addr, cfg.Auth, saslCfg)
if err != nil {
return nil, fmt.Errorf("create SasalTranposrt failed: %v", err)
}
} else {
return nil, fmt.Errorf("unrecognized auth mechanism: %s", cfg.Auth)
}

if transport == nil {
return nil, errors.New("nil thrift transport")
if err = transport.Open(); err != nil {
return nil, err
}

protocol := thrift.NewTBinaryProtocolFactoryDefault()
client := hiveserver2.NewTCLIServiceClientFactory(transport, protocol)
s := hiveserver2.NewTOpenSessionReq()
s.ClientProtocol = 6
s.ClientProtocol = hiveserver2.TProtocolVersion_HIVE_CLI_SERVICE_PROTOCOL_V6
if cfg.User != "" {
s.Username = &cfg.User
if cfg.Passwd != "" {
Expand Down
10 changes: 10 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ func TestOpenConnection(t *testing.T) {
defer db.Close()
}

func TestOpenConnectionAgainstAuth(t *testing.T) {
db, _ := sql.Open("hive", "127.0.0.1:10000/churn?auth=PLAIN")
rows, err := db.Query("SELECT customerID, gender FROM train")
assert.EqualError(t, err, "Bad SASL negotiation status: 4 ()")
defer db.Close()
if err == nil {
defer rows.Close()
}
}

func TestQuery(t *testing.T) {
db, _ := sql.Open("hive", "127.0.0.1:10000/churn")
rows, err := db.Query("SELECT customerID, gender FROM train")
Expand Down
46 changes: 36 additions & 10 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@ type Config struct {
Passwd string
Addr string
DBName string
Auth string
}

var (
// Regexp syntax: https://github.com/google/re2/wiki/Syntax
reDSN = regexp.MustCompile(`(.+@)?([^@]+)`)
reDSN = regexp.MustCompile(`(.+@)?([^@|^?]+)\\?(.*)`)
reUserPasswd = regexp.MustCompile(`([^:@]+)(:[^:@]+)?@`)
reArguments = regexp.MustCompile(`(\w+)=(\w+)`)
)

// ParseDSN requires DSN names in the format [user[:password]@]addr/dbname.
func ParseDSN(dsn string) (*Config, error) {
// Please read https://play.golang.org/p/_CSLvl1AxOX before code review.
sub := reDSN.FindStringSubmatch(dsn)
if len(sub) != 3 {
return nil, fmt.Errorf("The DSN %s doesn't match [user[:password]@]addr[/dbname]", dsn)
if len(sub) != 4 {
return nil, fmt.Errorf("The DSN %s doesn't match [user[:password]@]addr[/dbname][?auth=AUTH_MECHANISM]", dsn)
}
addr := ""
dbname := ""
Expand All @@ -35,21 +37,45 @@ func ParseDSN(dsn string) (*Config, error) {
} else {
addr = sub[2]
}
user := ""
passwd := ""
auth := "NOSASL"
up := reUserPasswd.FindStringSubmatch(sub[1])
if len(up) == 3 {
user = up[1]
if len(up[2]) > 0 {
return &Config{User: up[1], Passwd: up[2][1:], Addr: addr, DBName: dbname}, nil
passwd = up[2][1:]
}
return &Config{User: up[1], Addr: addr, DBName: dbname}, nil
}
return &Config{Addr: addr, DBName: dbname}, nil

args := reArguments.FindAllStringSubmatch(sub[3], -1)
if len(args) > 1 {
return nil, fmt.Errorf("The DSN %s doesn't match [user[:password]@]addr[/dbname][?auth=AUTH_MECHANISM]", dsn)
}

if len(args) == 1 {
if args[0][1] != "auth" {
return nil, fmt.Errorf("The DSN %s doesn't match [user[:password]@]addr[/dbname][?auth=AUTH_MECHANISM]", dsn)
}
auth = args[0][2]
}
return &Config{
User: user,
Passwd: passwd,
Addr: addr,
DBName: dbname,
Auth: auth,
}, nil
}

// FormatDSN outputs a string in the format "user:password@address"
// FormatDSN outputs a string in the format "user:password@address?auth=xxx"
func (cfg *Config) FormatDSN() string {
dsn := fmt.Sprintf("%s:%s@%s", cfg.User, cfg.Passwd, cfg.Addr)
if len(cfg.DBName) > 0 {
return fmt.Sprintf("%s:%s@%s/%s", cfg.User, cfg.Passwd, cfg.Addr, cfg.DBName)
} else {
return fmt.Sprintf("%s:%s@%s", cfg.User, cfg.Passwd, cfg.Addr)
dsn = fmt.Sprintf("%s/%s", dsn, cfg.DBName)
}
if len(cfg.Auth) > 0 {
dsn = fmt.Sprintf("%s?auth=%s", dsn, cfg.Auth)
}
return dsn
}
22 changes: 20 additions & 2 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@ import (
"github.com/stretchr/testify/assert"
)

func TestParseDSNWithAuth(t *testing.T) {
cfg, e := ParseDSN("root:root@127.0.0.1/mnist?auth=PLAIN")
assert.Nil(t, e)
assert.Equal(t, cfg.User, "root")
assert.Equal(t, cfg.Passwd, "root")
assert.Equal(t, cfg.Addr, "127.0.0.1")
assert.Equal(t, cfg.DBName, "mnist")
assert.Equal(t, cfg.Auth, "PLAIN")

cfg, e = ParseDSN("root@127.0.0.1/mnist")
assert.Nil(t, e)
assert.Equal(t, cfg.User, "root")
assert.Equal(t, cfg.Passwd, "")
assert.Equal(t, cfg.Addr, "127.0.0.1")
assert.Equal(t, cfg.DBName, "mnist")
assert.Equal(t, cfg.Auth, "NOSASL")
}

func TestParseDSNWithDBName(t *testing.T) {
cfg, e := ParseDSN("root:root@127.0.0.1/mnist")
assert.Nil(t, e)
Expand Down Expand Up @@ -50,7 +68,7 @@ func TestParseDSNWithoutDBName(t *testing.T) {
}

func TestFormatDSNWithDBName(t *testing.T) {
ds := "user:passwd@127.0.0.1/mnist"
ds := "user:passwd@127.0.0.1/mnist?auth=NOSASL"
cfg, e := ParseDSN(ds)
assert.Nil(t, e)

Expand All @@ -59,7 +77,7 @@ func TestFormatDSNWithDBName(t *testing.T) {
}

func TestFormatDSNWithoutDBName(t *testing.T) {
ds := "user:passwd@127.0.0.1"
ds := "user:passwd@127.0.0.1?auth=NOSASL"
cfg, e := ParseDSN(ds)
assert.Nil(t, e)

Expand Down

0 comments on commit abe07b6

Please sign in to comment.