Skip to content

Commit

Permalink
Overhaul *Conn.WithTransaction; add *Conn.WithSavepoint; some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lxn committed Aug 15, 2010
1 parent b07b8b1 commit 316ffe1
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 25 deletions.
137 changes: 122 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,29 @@ import (
"os"
)

// LogLevel is used to control which messages are written to the log.
// LogLevel is used to control what written to the log.
type LogLevel int

const (
// Log nothing.
LogNothing LogLevel = iota

// Log fatal errors.
LogFatal

// Log all errors.
LogError

// Log errors and warnings.
LogWarning

// Log errors, warnings and sent commands.
LogCommand

// Log errors, warnings, sent commands and additional debug info.
LogDebug

// Log everything.
LogVerbose
)

Expand Down Expand Up @@ -60,25 +74,45 @@ func (s ConnStatus) String() string {
return "Unknown"
}

// IsolationLevel represents the isolation level of a transaction.
type IsolationLevel int

const (
ReadCommittedIsolation IsolationLevel = iota
SerializableIsolation
)

func (il IsolationLevel) String() string {
switch il {
case ReadCommittedIsolation:
return "Read Committed"

case SerializableIsolation:
return "Serializable"
}

return "Unknown"
}

// TransactionStatus represents the transaction status of a connection.
type TransactionStatus byte

const (
NotInTransaction TransactionStatus = 'I'
InTransaction TransactionStatus = 'T'
ErrorInTransaction TransactionStatus = 'E'
NotInTransaction TransactionStatus = 'I'
InTransaction TransactionStatus = 'T'
InFailedTransaction TransactionStatus = 'E'
)

func (s TransactionStatus) String() string {
switch s {
case NotInTransaction:
return "Not in transaction"
return "Not In Transaction"

case InTransaction:
return "In transaction"
return "In Transaction"

case ErrorInTransaction:
return "Error in transaction"
case InFailedTransaction:
return "In Failed Transaction"
}

return "Unknown"
Expand All @@ -98,6 +132,7 @@ type Conn struct {
runtimeParameters map[string]string
nextStatementId uint64
nextPortalId uint64
nextSavepointId uint64
transactionStatus TransactionStatus
}

Expand Down Expand Up @@ -308,24 +343,97 @@ func (conn *Conn) TransactionStatus() TransactionStatus {
return conn.transactionStatus
}

// WithTransaction starts a transaction, then calls function f.
// If f returns an error or panicks, the transaction is rolled back,
// otherwise it is committed.
func (conn *Conn) WithTransaction(f func() os.Error) (err os.Error) {
// WithTransaction starts a new transaction, if none is in progress, then
// calls f. If f returns an error or panicks, the transaction is rolled back,
// otherwise it is committed. If the connection is in a failed transaction when
// calling WithTransaction, this function immediately returns with an error,
// without calling f. In case of an active transaction without error,
// WithTransaction just calls f.
func (conn *Conn) WithTransaction(isolation IsolationLevel, f func() os.Error) (err os.Error) {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.WithTransaction"))
}

oldStatus := conn.transactionStatus

if oldStatus == InFailedTransaction {
return conn.logAndConvertPanic("error in transaction")
}

defer func() {
if x := recover(); x != nil {
err = conn.logAndConvertPanic(x)
}
if err != nil {
if err == nil && conn.transactionStatus == InFailedTransaction {
err = conn.logAndConvertPanic("error in transaction")
}
if err != nil && oldStatus == NotInTransaction {
conn.Execute("ROLLBACK;")
}
}()

_, err = conn.Execute("BEGIN;")
if oldStatus == NotInTransaction {
var isol string
if isolation == SerializableIsolation {
isol = "SERIALIZABLE"
} else {
isol = "READ COMMITTED"
}
cmd := fmt.Sprintf("BEGIN; SET TRANSACTION ISOLATION LEVEL %s;", isol)
_, err = conn.Execute(cmd)
if err != nil {
panic(err)
}
}

err = f()
if err != nil {
panic(err)
}

if oldStatus == NotInTransaction && conn.transactionStatus == InTransaction {
_, err = conn.Execute("COMMIT;")
}
return
}

// WithSavepoint creates a transaction savepoint, if the connection is in an
// active transaction without errors, then calls f. If f returns an error or
// panicks, the transaction is rolled back to the savepoint. If the connection
// is in a failed transaction when calling WithSavepoint, this function
// immediately returns with an error, without calling f. If no transaction is in
// progress, instead of creating a savepoint, a new transaction is started.
func (conn *Conn) WithSavepoint(isolation IsolationLevel, f func() os.Error) (err os.Error) {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.WithSavepoint"))
}

oldStatus := conn.transactionStatus

switch oldStatus {
case InFailedTransaction:
return conn.logAndConvertPanic("error in transaction")

case NotInTransaction:
return conn.WithTransaction(isolation, f)
}

savepointName := fmt.Sprintf("sp%d", conn.nextSavepointId)
conn.nextSavepointId++

defer func() {
if x := recover(); x != nil {
err = conn.logAndConvertPanic(x)
}
if err == nil && conn.transactionStatus == InFailedTransaction {
err = conn.logAndConvertPanic("error in transaction")
}
if err != nil {
conn.Execute(fmt.Sprintf("ROLLBACK TO %s;", savepointName))
}
}()

_, err = conn.Execute(fmt.Sprintf("SAVEPOINT %s;", savepointName))
if err != nil {
panic(err)
}
Expand All @@ -335,6 +443,5 @@ func (conn *Conn) WithTransaction(f func() os.Error) (err os.Error) {
panic(err)
}

_, err = conn.Execute("COMMIT;")
return
}
9 changes: 1 addition & 8 deletions conn_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,7 @@ func (conn *Conn) readCommandComplete(res *ResultSet) {
if res != nil {
parts := strings.Split(tag, " ", -1)

rowsAffected, err := strconv.Atoi64(parts[len(parts)-1])
if err != nil {
if conn.LogLevel >= LogWarning {
conn.log(LogWarning, "failed to retrieve affected row count")
}
}

res.rowsAffected = rowsAffected
res.rowsAffected, _ = strconv.Atoi64(parts[len(parts)-1])
res.currentResultComplete = true
}
}
Expand Down
8 changes: 6 additions & 2 deletions conn_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,10 @@ func (conn *Conn) writeExecute(stmt *Statement) {
func (conn *Conn) writeParse(stmt *Statement) {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.writeParse"))
}

conn.log(LogDebug, fmt.Sprintf("stmt.ActualCommand: '%s'", stmt.ActualCommand()))
if conn.LogLevel >= LogCommand {
conn.log(LogCommand, fmt.Sprintf("stmt.ActualCommand: '%s'", stmt.ActualCommand()))
}

msgLen := int32(4 +
Expand Down Expand Up @@ -285,8 +287,10 @@ func (conn *Conn) writePasswordMessage(password string) {
func (conn *Conn) writeQuery(command string) {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.writeQuery"))
}

conn.log(LogDebug, fmt.Sprintf("command: '%s'", command))
if conn.LogLevel >= LogCommand {
conn.log(LogCommand, fmt.Sprintf("command: '%s'", command))
}

conn.writeFrontendMessageCode(_Query)
Expand Down
101 changes: 101 additions & 0 deletions pgsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package pgsql
import (
"fmt"
"math"
"os"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -608,3 +609,103 @@ func Test_Insert_Time(t *testing.T) {
})
}
}

func Test_Conn_WithSavepoint(t *testing.T) {
withConn(t, func(conn *Conn) {
conn.Execute("DROP TABLE _gopgsql_test_account;")

_, err := conn.Execute(`
CREATE TABLE _gopgsql_test_account
(
name VARCHAR(20) PRIMARY KEY,
balance REAL NOT NULL
);
INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Alice', 100.0);
INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Bob', 0.0);
INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Wally', 0.0);
`)
if err != nil {
t.Error("failed to create table:", err)
return
}
defer func() {
conn.Execute("DROP TABLE _gopgsql_test_account;")
}()

err = conn.WithTransaction(ReadCommittedIsolation, func() (err os.Error) {
_, err = conn.Execute(`
UPDATE _gopgsql_test_account
SET balance = balance - 100.0
WHERE name = 'Alice';`)
if err != nil {
t.Error("failed to execute update:", err)
return
}

err = conn.WithSavepoint(ReadCommittedIsolation, func() (err os.Error) {
_, err = conn.Execute(`
UPDATE _gopgsql_test_account
SET balance = balance + 100.0
WHERE name = 'Bob';`)
if err != nil {
t.Error("failed to execute update:", err)
return
}

err = os.NewError("wrong credit account")

return
})

_, err = conn.Execute(`
UPDATE _gopgsql_test_account
SET balance = balance + 100.0
WHERE name = 'Wally';`)
if err != nil {
t.Error("failed to execute update:", err)
return
}

return
})

var rs *ResultSet
rs, err = conn.Query("SELECT name, balance FROM _gopgsql_test_account;")
if err != nil {
t.Error("failed to query:", err)
return
}
defer rs.Close()

have := make(map[string]float64)
want := map[string]float64{
"Alice": 0,
"Bob": 0,
"Wally": 100,
}
var name string
var balance float64
var fetched bool

for {
fetched, err = rs.ScanNext(&name, &balance)
if err != nil {
t.Error("failed to scan next:", err)
return
}
if !fetched {
break
}

have[name] = balance
}

for name, haveBalance := range have {
wantBalance := want[name]

if math.Fabs(haveBalance-wantBalance) > 0.000001 {
t.Errorf("name: %s have: %f, but want: %f", name, haveBalance, wantBalance)
}
}
})
}

0 comments on commit 316ffe1

Please sign in to comment.