Skip to content

Commit

Permalink
Initial support for date/time types; expose runtime parameters; fix e…
Browse files Browse the repository at this point in the history
…xecution of prepared statements that return no data; maybe some more fixes ;)
  • Loading branch information
lxn committed Aug 14, 2010
1 parent d7c5a79 commit 3072a33
Show file tree
Hide file tree
Showing 13 changed files with 543 additions and 90 deletions.
2 changes: 1 addition & 1 deletion README
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ go-pgsql is currently missing support for some features, including:

- authentication types other than MD5
- SSL encrypted sessions
- some data types like date/time, bytea, ...
- some data types like bytea, ...
- canceling commands/queries
- bulk copy
...
Expand Down
16 changes: 15 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type Conn struct {
backendPID int32
backendSecretKey int32
onErrorDontRequireReadyForQuery bool
runtimeParameters map[string]string
}

// Connect establishes a database connection.
Expand All @@ -95,7 +96,7 @@ func Connect(parameters *ConnParams) (conn *Conn, err os.Error) {
newConn.params = params

if params.Host == "" {
params.Host = "127.0.0.1"
params.Host = "localhost"
}
if params.Port == 0 {
params.Port = 5432
Expand All @@ -116,6 +117,8 @@ func Connect(parameters *ConnParams) (conn *Conn, err os.Error) {
newConn.reader = bufio.NewReader(tcpConn)
newConn.writer = bufio.NewWriter(tcpConn)

newConn.runtimeParameters = make(map[string]string)

newConn.writeStartup()

newConn.readBackendMessages(nil)
Expand Down Expand Up @@ -231,6 +234,17 @@ func (conn *Conn) Query(command string) (res *ResultSet, err os.Error) {
return
}

// RuntimeParameter returns the value of the specified runtime parameter.
// If the value was successfully retrieved, ok is true, otherwise false.
func (conn *Conn) RuntimeParameter(name string) (value string, ok bool) {
if conn.LogLevel >= LogVerbose {
defer conn.logExit(conn.logEnter("*Conn.RuntimeParameter"))
}

value, ok = conn.runtimeParameters[name]
return
}

// Scan executes the command and scans the fields of the first row
// in the ResultSet, trying to store field values into the specified
// arguments. The arguments must be of pointer types. If a row has
Expand Down
4 changes: 2 additions & 2 deletions conn_log.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ func (conn *Conn) logError(level LogLevel, err os.Error) {
}

func (conn *Conn) logEnter(funcName string) string {
conn.log(LogDebug, "entering: ", "pgsqconn."+funcName)
conn.log(LogDebug, "entering: ", "pgsql."+funcName)
return funcName
}

func (conn *Conn) logExit(funcName string) {
conn.log(LogDebug, "exiting: ", "pgsqconn."+funcName)
conn.log(LogDebug, "exiting: ", "pgsql."+funcName)
}

func (conn *Conn) logAndConvertPanic(x interface{}) (err os.Error) {
Expand Down
138 changes: 76 additions & 62 deletions conn_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,68 +59,6 @@ func (conn *Conn) readString() string {
return string(b[0 : len(b)-1])
}

func (conn *Conn) readDataRow(res *ResultSet) {
// Just eat message length.
conn.readInt32()

fieldCount := conn.readInt16()

var ord int16
for ord = 0; ord < fieldCount; ord++ {
valLen := conn.readInt32()

var val []byte

if valLen == -1 {
val = nil
} else {
val = make([]byte, valLen)
conn.read(val)
}

res.values[ord] = val
}
}

func (conn *Conn) readRowDescription(res *ResultSet) {
// Just eat message length.
conn.readInt32()

fieldCount := conn.readInt16()

res.fields = make([]field, fieldCount)
res.values = make([][]byte, fieldCount)

var ord int16
for ord = 0; ord < fieldCount; ord++ {
res.fields[ord].name = conn.readString()

// Just eat table OID.
conn.readInt32()

// Just eat field OID.
conn.readInt16()

// Just eat field data type OID.
conn.readInt32()

// Just eat field size.
conn.readInt16()

// Just eat field type modifier.
conn.readInt32()

format := fieldFormat(conn.readInt16())
switch format {
case textFormat:
case binaryFormat:
default:
panic("unsupported field format")
}
res.fields[ord].format = format
}
}

func (conn *Conn) readAuthenticationRequest() {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.readAuthenticationRequest"))
Expand Down Expand Up @@ -242,6 +180,29 @@ func (conn *Conn) readCommandComplete(res *ResultSet) {
}
}

func (conn *Conn) readDataRow(res *ResultSet) {
// Just eat message length.
conn.readInt32()

fieldCount := conn.readInt16()

var ord int16
for ord = 0; ord < fieldCount; ord++ {
valLen := conn.readInt32()

var val []byte

if valLen == -1 {
val = nil
} else {
val = make([]byte, valLen)
conn.read(val)
}

res.values[ord] = val
}
}

func (conn *Conn) readEmptyQueryResponse() {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.readEmptyQueryResponse"))
Expand Down Expand Up @@ -323,6 +284,15 @@ func (conn *Conn) readErrorOrNoticeResponse(isError bool) {
}
}

func (conn *Conn) readNoData() {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.readNoData"))
}

// Just eat message length.
conn.readInt32()
}

func (conn *Conn) readParameterStatus() {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.readParameterStatus"))
Expand All @@ -337,6 +307,8 @@ func (conn *Conn) readParameterStatus() {
if conn.LogLevel >= LogDebug {
conn.logf(LogDebug, "ParameterStatus: Name: '%s', Value: '%s'", name, value)
}

conn.runtimeParameters[name] = value
}

func (conn *Conn) readParseComplete() {
Expand Down Expand Up @@ -381,6 +353,44 @@ func (conn *Conn) readReadyForQuery(res *ResultSet) {
conn.state = readyState{}
}

func (conn *Conn) readRowDescription(res *ResultSet) {
// Just eat message length.
conn.readInt32()

fieldCount := conn.readInt16()

res.fields = make([]field, fieldCount)
res.values = make([][]byte, fieldCount)

var ord int16
for ord = 0; ord < fieldCount; ord++ {
res.fields[ord].name = conn.readString()

// Just eat table OID.
conn.readInt32()

// Just eat field OID.
conn.readInt16()

res.fields[ord].typeOID = conn.readInt32()

// Just eat field size.
conn.readInt16()

// Just eat field type modifier.
conn.readInt32()

format := fieldFormat(conn.readInt16())
switch format {
case textFormat:
case binaryFormat:
default:
panic("unsupported field format")
}
res.fields[ord].format = format
}
}

func (conn *Conn) readBackendMessages(res *ResultSet) {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.readBackendMessages"))
Expand Down Expand Up @@ -421,6 +431,10 @@ func (conn *Conn) readBackendMessages(res *ResultSet) {
case _ErrorResponse:
conn.readErrorOrNoticeResponse(true)

case _NoData:
conn.readNoData()
return

case _NoticeResponse:
conn.readErrorOrNoticeResponse(false)

Expand Down
42 changes: 41 additions & 1 deletion conn_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"math"
"strconv"
"time"
)

func (conn *Conn) flush() {
Expand Down Expand Up @@ -134,13 +135,40 @@ func (conn *Conn) writeBind(stmt *Statement) {
values[i] = strconv.Itoa(int(val))

case int64:
values[i] = strconv.Itoa64(val)
switch param.typ {
case Date:
values[i] = time.SecondsToUTC(val).Format("2006-01-02")

case Time, TimeTZ:
values[i] = time.SecondsToUTC(val).Format("15:04:05")

case Timestamp, TimestampTZ:
values[i] = time.SecondsToUTC(val).Format("2006-01-02 15:04:05")

default:
values[i] = strconv.Itoa64(val)
}

case nil:

case string:
values[i] = val

case *time.Time:
switch param.typ {
case Date:
values[i] = val.Format("2006-01-02")

case Time, TimeTZ:
values[i] = val.Format("15:04:05")

case Timestamp, TimestampTZ:
values[i] = val.Format("2006-01-02 15:04:05")

default:
panic("invalid use of *time.Time")
}

default:
panic("unsupported parameter type")
}
Expand Down Expand Up @@ -212,6 +240,12 @@ func (conn *Conn) writeExecute(stmt *Statement) {
}

func (conn *Conn) writeParse(stmt *Statement) {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.writeParse"))

conn.log(LogDebug, fmt.Sprintf("stmt.ActualCommand: '%s'", stmt.ActualCommand()))
}

msgLen := int32(4 +
len(stmt.name) + 1 +
len(stmt.actualCommand) + 1 +
Expand Down Expand Up @@ -245,6 +279,12 @@ func (conn *Conn) writePasswordMessage(password string) {
}

func (conn *Conn) writeQuery(command string) {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.writeQuery"))

conn.log(LogDebug, fmt.Sprintf("command: '%s'", command))
}

conn.writeFrontendMessageCode(_Query)
conn.writeInt32(int32(4 + len(command) + 1))
conn.writeString0(command)
Expand Down
1 change: 0 additions & 1 deletion examples/multipleselects.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ func main() {
pgsql.DefaultLogLevel = pgsql.LogError

params := &pgsql.ConnParams{
Host: "127.0.0.1",
Database: "testdatabase",
User: "testuser",
Password: "testpassword",
Expand Down
1 change: 0 additions & 1 deletion examples/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func main() {
pgsql.DefaultLogLevel = pgsql.LogError

params := &pgsql.ConnParams{
Host: "127.0.0.1",
Database: "testdatabase",
User: "testuser",
Password: "testpassword",
Expand Down
1 change: 0 additions & 1 deletion examples/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ func main() {
pgsql.DefaultLogLevel = pgsql.LogError

params := &pgsql.ConnParams{
Host: "127.0.0.1",
Database: "testdatabase",
User: "testuser",
Password: "testpassword",
Expand Down
Loading

0 comments on commit 3072a33

Please sign in to comment.