From 905334b570f4a25d6ebc369bc29e4dc6025899aa Mon Sep 17 00:00:00 2001 From: Brent Rowland Date: Tue, 3 Feb 2015 17:14:28 -0800 Subject: [PATCH] Add ROLE support. --- connection.go | 8 ++++---- driver_test.go | 39 +++++++++++++++++++++++++++++++++++++++ utils.go | 3 ++- utils_test.go | 14 ++++++++------ wireprotocol.go | 9 ++++++--- 5 files changed, 59 insertions(+), 14 deletions(-) diff --git a/connection.go b/connection.go index 05f52407..9a61dbdc 100644 --- a/connection.go +++ b/connection.go @@ -84,7 +84,7 @@ func (fc *firebirdsqlConn) Query(query string, args []driver.Value) (rows driver } func newFirebirdsqlConn(dsn string) (fc *firebirdsqlConn, err error) { - addr, dbName, user, password, err := parseDSN(dsn) + addr, dbName, user, password, role, err := parseDSN(dsn) wp, err := newWireProtocol(addr) if err != nil { return @@ -96,7 +96,7 @@ func newFirebirdsqlConn(dsn string) (fc *firebirdsqlConn, err error) { if err != nil { return } - wp.opAttach(dbName, user, password) + wp.opAttach(dbName, user, password, role) wp.dbHandle, _, _, err = wp.opResponse() if err != nil { return @@ -118,7 +118,7 @@ func newFirebirdsqlConn(dsn string) (fc *firebirdsqlConn, err error) { func createFirebirdsqlConn(dsn string) (fc *firebirdsqlConn, err error) { // Create Database - addr, dbName, user, password, err := parseDSN(dsn) + addr, dbName, user, password, role, err := parseDSN(dsn) wp, err := newWireProtocol(addr) if err != nil { return @@ -131,7 +131,7 @@ func createFirebirdsqlConn(dsn string) (fc *firebirdsqlConn, err error) { if err != nil { return } - wp.opCreate(dbName, user, password) + wp.opCreate(dbName, user, password, role) wp.dbHandle, _, _, err = wp.opResponse() fc = new(firebirdsqlConn) diff --git a/driver_test.go b/driver_test.go index beb21757..d3fc73a7 100644 --- a/driver_test.go +++ b/driver_test.go @@ -271,6 +271,45 @@ func TestError(t *testing.T) { } } +func TestRole(t *testing.T) { + conn1, err := sql.Open("firebirdsql_createdb", "sysdba:masterkey@localhost:3050/tmp/go_test_role.fdb") + if err != nil { + t.Fatalf("Error creating: %v", err) + } + conn1.Exec("CREATE TABLE test_role (f1 integer)") + conn1.Exec("INSERT INTO test_role (f1) values (1)") + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + conn1.Exec("CREATE ROLE DRIVERROLE") + if err != nil { + t.Fatalf("Error creating role: %v", err) + } + conn1.Exec("GRANT DRIVERROLE TO DRIVERTEST") + if err != nil { + t.Fatalf("Error creating role: %v", err) + } + conn1.Exec("GRANT SELECT ON test_role TO DRIVERROLE") + if err != nil { + t.Fatalf("Error granting right to role: %v", err) + } + conn1.Close() + + conn2, err := sql.Open("firebirdsql", "drivertest:driverpw:driverrole@localhost:3050/tmp/go_test_role.fdb") + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + + rows, err := conn2.Query("SELECT f1 FROM test_role") + defer conn2.Close() + if err != nil { + t.Fatalf("Error Query: %v", err) + } + + for rows.Next() { + } +} + /* func TestFB3(t *testing.T) { conn, err := sql.Open("firebirdsql_createdb", "sysdba:masterkey@localhost:3050/tmp/go_test_fb3.fdb") diff --git a/utils.go b/utils.go index 02c39ed6..2aa8fbd4 100644 --- a/utils.go +++ b/utils.go @@ -309,9 +309,10 @@ func split1(src string, delm string) (string, string) { return src, "" } -func parseDSN(dsn string) (addr string, dbName string, user string, passwd string, err error) { +func parseDSN(dsn string) (addr string, dbName string, user string, passwd string, role string, err error) { s1, s2 := split1(dsn, "@") user, passwd = split1(s1, ":") + passwd, role = split1(passwd, ":") addr, dbName = split1(s2, "/") if !strings.ContainsRune(addr, ':') { addr += ":3050" diff --git a/utils_test.go b/utils_test.go index 2f91ff1d..98c3b44c 100644 --- a/utils_test.go +++ b/utils_test.go @@ -35,16 +35,18 @@ func TestDSNParse(t *testing.T) { dbName string user string passwd string + role string }{ - {"user:password@localhost:3000/dbname", "localhost:3000", "dbname", "user", "password"}, - {"user:password@localhost/dbname", "localhost:3050", "dbname", "user", "password"}, - {"user:password@localhost/dir/dbname", "localhost:3050", "/dir/dbname", "user", "password"}, - {"user:password@localhost/c:\\fbdata\\database.fdb", "localhost:3050", "c:\\fbdata\\database.fdb", "user", "password"}, + {"user:password@localhost:3000/dbname", "localhost:3000", "dbname", "user", "password", ""}, + {"user:password@localhost/dbname", "localhost:3050", "dbname", "user", "password", ""}, + {"user:password@localhost/dir/dbname", "localhost:3050", "/dir/dbname", "user", "password", ""}, + {"user:password@localhost/c:\\fbdata\\database.fdb", "localhost:3050", "c:\\fbdata\\database.fdb", "user", "password", ""}, + {"user:password:role@localhost/dbname", "localhost:3050", "dbname", "user", "password", "role"}, } for _, d := range testDSNs { - addr, dbName, user, passwd, err := parseDSN(d.dsn) - if addr != d.addr || dbName != d.dbName || user != d.user || passwd != d.passwd { + addr, dbName, user, passwd, role, err := parseDSN(d.dsn) + if addr != d.addr || dbName != d.dbName || user != d.user || passwd != d.passwd || role != d.role { err = errors.New("parse DSN fail") } if err != nil { diff --git a/wireprotocol.go b/wireprotocol.go index 1be54d4d..1b0312d2 100644 --- a/wireprotocol.go +++ b/wireprotocol.go @@ -474,7 +474,7 @@ func (p *wireProtocol) opConnect(dbName string, user string, password string, cl p.sendPackets() } -func (p *wireProtocol) opCreate(dbName string, user string, password string) { +func (p *wireProtocol) opCreate(dbName string, user string, password string, role string) { debugPrint(p, "opCreate") var page_size int32 page_size = 4096 @@ -482,12 +482,14 @@ func (p *wireProtocol) opCreate(dbName string, user string, password string) { encode := bytes.NewBufferString("UTF8").Bytes() userBytes := bytes.NewBufferString(strings.ToUpper(user)).Bytes() passwordBytes := bytes.NewBufferString(password).Bytes() + roleBytes := []byte(role) dpb := bytes.Join([][]byte{ []byte{1}, []byte{68, byte(len(encode))}, encode, []byte{48, byte(len(encode))}, encode, []byte{28, byte(len(userBytes))}, userBytes, []byte{29, byte(len(passwordBytes))}, passwordBytes, + []byte{60, byte(len(roleBytes))}, roleBytes, []byte{63, 4}, int32_to_bytes(3), []byte{24, 4}, bint32_to_bytes(1), []byte{54, 4}, bint32_to_bytes(1), @@ -606,17 +608,18 @@ func (p *wireProtocol) opAccept(user string, password string, clientPublic *big. return } -func (p *wireProtocol) opAttach(dbName string, user string, password string) { +func (p *wireProtocol) opAttach(dbName string, user string, password string, role string) { debugPrint(p, "opAttach") encode := bytes.NewBufferString("UTF8").Bytes() userBytes := bytes.NewBufferString(strings.ToUpper(user)).Bytes() passwordBytes := bytes.NewBufferString(password).Bytes() - + roleBytes := []byte(role) dbp := bytes.Join([][]byte{ []byte{1}, []byte{48, byte(len(encode))}, encode, []byte{28, byte(len(userBytes))}, userBytes, []byte{29, byte(len(passwordBytes))}, passwordBytes, + []byte{60, byte(len(roleBytes))}, roleBytes, }, nil) p.packInt(op_attach) p.packInt(0) // Database Object ID