diff --git a/conn.go b/conn.go index 20050ec..594e9be 100644 --- a/conn.go +++ b/conn.go @@ -13,15 +13,29 @@ import ( "os" ) -// LogLevel is used to control which messages are written to the log. +// LogLevel is used to control what written to the log. type LogLevel int const ( + // Log nothing. LogNothing LogLevel = iota + + // Log fatal errors. LogFatal + + // Log all errors. LogError + + // Log errors and warnings. LogWarning + + // Log errors, warnings and sent commands. + LogCommand + + // Log errors, warnings, sent commands and additional debug info. LogDebug + + // Log everything. LogVerbose ) @@ -60,25 +74,45 @@ func (s ConnStatus) String() string { return "Unknown" } +// IsolationLevel represents the isolation level of a transaction. +type IsolationLevel int + +const ( + ReadCommittedIsolation IsolationLevel = iota + SerializableIsolation +) + +func (il IsolationLevel) String() string { + switch il { + case ReadCommittedIsolation: + return "Read Committed" + + case SerializableIsolation: + return "Serializable" + } + + return "Unknown" +} + // TransactionStatus represents the transaction status of a connection. type TransactionStatus byte const ( - NotInTransaction TransactionStatus = 'I' - InTransaction TransactionStatus = 'T' - ErrorInTransaction TransactionStatus = 'E' + NotInTransaction TransactionStatus = 'I' + InTransaction TransactionStatus = 'T' + InFailedTransaction TransactionStatus = 'E' ) func (s TransactionStatus) String() string { switch s { case NotInTransaction: - return "Not in transaction" + return "Not In Transaction" case InTransaction: - return "In transaction" + return "In Transaction" - case ErrorInTransaction: - return "Error in transaction" + case InFailedTransaction: + return "In Failed Transaction" } return "Unknown" @@ -98,6 +132,7 @@ type Conn struct { runtimeParameters map[string]string nextStatementId uint64 nextPortalId uint64 + nextSavepointId uint64 transactionStatus TransactionStatus } @@ -308,24 +343,97 @@ func (conn *Conn) TransactionStatus() TransactionStatus { return conn.transactionStatus } -// WithTransaction starts a transaction, then calls function f. -// If f returns an error or panicks, the transaction is rolled back, -// otherwise it is committed. -func (conn *Conn) WithTransaction(f func() os.Error) (err os.Error) { +// WithTransaction starts a new transaction, if none is in progress, then +// calls f. If f returns an error or panicks, the transaction is rolled back, +// otherwise it is committed. If the connection is in a failed transaction when +// calling WithTransaction, this function immediately returns with an error, +// without calling f. In case of an active transaction without error, +// WithTransaction just calls f. +func (conn *Conn) WithTransaction(isolation IsolationLevel, f func() os.Error) (err os.Error) { if conn.LogLevel >= LogDebug { defer conn.logExit(conn.logEnter("*Conn.WithTransaction")) } + oldStatus := conn.transactionStatus + + if oldStatus == InFailedTransaction { + return conn.logAndConvertPanic("error in transaction") + } + defer func() { if x := recover(); x != nil { err = conn.logAndConvertPanic(x) } - if err != nil { + if err == nil && conn.transactionStatus == InFailedTransaction { + err = conn.logAndConvertPanic("error in transaction") + } + if err != nil && oldStatus == NotInTransaction { conn.Execute("ROLLBACK;") } }() - _, err = conn.Execute("BEGIN;") + if oldStatus == NotInTransaction { + var isol string + if isolation == SerializableIsolation { + isol = "SERIALIZABLE" + } else { + isol = "READ COMMITTED" + } + cmd := fmt.Sprintf("BEGIN; SET TRANSACTION ISOLATION LEVEL %s;", isol) + _, err = conn.Execute(cmd) + if err != nil { + panic(err) + } + } + + err = f() + if err != nil { + panic(err) + } + + if oldStatus == NotInTransaction && conn.transactionStatus == InTransaction { + _, err = conn.Execute("COMMIT;") + } + return +} + +// WithSavepoint creates a transaction savepoint, if the connection is in an +// active transaction without errors, then calls f. If f returns an error or +// panicks, the transaction is rolled back to the savepoint. If the connection +// is in a failed transaction when calling WithSavepoint, this function +// immediately returns with an error, without calling f. If no transaction is in +// progress, instead of creating a savepoint, a new transaction is started. +func (conn *Conn) WithSavepoint(isolation IsolationLevel, f func() os.Error) (err os.Error) { + if conn.LogLevel >= LogDebug { + defer conn.logExit(conn.logEnter("*Conn.WithSavepoint")) + } + + oldStatus := conn.transactionStatus + + switch oldStatus { + case InFailedTransaction: + return conn.logAndConvertPanic("error in transaction") + + case NotInTransaction: + return conn.WithTransaction(isolation, f) + } + + savepointName := fmt.Sprintf("sp%d", conn.nextSavepointId) + conn.nextSavepointId++ + + defer func() { + if x := recover(); x != nil { + err = conn.logAndConvertPanic(x) + } + if err == nil && conn.transactionStatus == InFailedTransaction { + err = conn.logAndConvertPanic("error in transaction") + } + if err != nil { + conn.Execute(fmt.Sprintf("ROLLBACK TO %s;", savepointName)) + } + }() + + _, err = conn.Execute(fmt.Sprintf("SAVEPOINT %s;", savepointName)) if err != nil { panic(err) } @@ -335,6 +443,5 @@ func (conn *Conn) WithTransaction(f func() os.Error) (err os.Error) { panic(err) } - _, err = conn.Execute("COMMIT;") return } diff --git a/conn_read.go b/conn_read.go index 600664e..b7d25bd 100644 --- a/conn_read.go +++ b/conn_read.go @@ -168,14 +168,7 @@ func (conn *Conn) readCommandComplete(res *ResultSet) { if res != nil { parts := strings.Split(tag, " ", -1) - rowsAffected, err := strconv.Atoi64(parts[len(parts)-1]) - if err != nil { - if conn.LogLevel >= LogWarning { - conn.log(LogWarning, "failed to retrieve affected row count") - } - } - - res.rowsAffected = rowsAffected + res.rowsAffected, _ = strconv.Atoi64(parts[len(parts)-1]) res.currentResultComplete = true } } diff --git a/conn_write.go b/conn_write.go index 4ebeac3..9ea11d8 100644 --- a/conn_write.go +++ b/conn_write.go @@ -246,8 +246,10 @@ func (conn *Conn) writeExecute(stmt *Statement) { func (conn *Conn) writeParse(stmt *Statement) { if conn.LogLevel >= LogDebug { defer conn.logExit(conn.logEnter("*Conn.writeParse")) + } - conn.log(LogDebug, fmt.Sprintf("stmt.ActualCommand: '%s'", stmt.ActualCommand())) + if conn.LogLevel >= LogCommand { + conn.log(LogCommand, fmt.Sprintf("stmt.ActualCommand: '%s'", stmt.ActualCommand())) } msgLen := int32(4 + @@ -285,8 +287,10 @@ func (conn *Conn) writePasswordMessage(password string) { func (conn *Conn) writeQuery(command string) { if conn.LogLevel >= LogDebug { defer conn.logExit(conn.logEnter("*Conn.writeQuery")) + } - conn.log(LogDebug, fmt.Sprintf("command: '%s'", command)) + if conn.LogLevel >= LogCommand { + conn.log(LogCommand, fmt.Sprintf("command: '%s'", command)) } conn.writeFrontendMessageCode(_Query) diff --git a/pgsql_test.go b/pgsql_test.go index f53028f..0d6a4fd 100644 --- a/pgsql_test.go +++ b/pgsql_test.go @@ -7,6 +7,7 @@ package pgsql import ( "fmt" "math" + "os" "strings" "testing" "time" @@ -608,3 +609,103 @@ func Test_Insert_Time(t *testing.T) { }) } } + +func Test_Conn_WithSavepoint(t *testing.T) { + withConn(t, func(conn *Conn) { + conn.Execute("DROP TABLE _gopgsql_test_account;") + + _, err := conn.Execute(` + CREATE TABLE _gopgsql_test_account + ( + name VARCHAR(20) PRIMARY KEY, + balance REAL NOT NULL + ); + INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Alice', 100.0); + INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Bob', 0.0); + INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Wally', 0.0); + `) + if err != nil { + t.Error("failed to create table:", err) + return + } + defer func() { + conn.Execute("DROP TABLE _gopgsql_test_account;") + }() + + err = conn.WithTransaction(ReadCommittedIsolation, func() (err os.Error) { + _, err = conn.Execute(` + UPDATE _gopgsql_test_account + SET balance = balance - 100.0 + WHERE name = 'Alice';`) + if err != nil { + t.Error("failed to execute update:", err) + return + } + + err = conn.WithSavepoint(ReadCommittedIsolation, func() (err os.Error) { + _, err = conn.Execute(` + UPDATE _gopgsql_test_account + SET balance = balance + 100.0 + WHERE name = 'Bob';`) + if err != nil { + t.Error("failed to execute update:", err) + return + } + + err = os.NewError("wrong credit account") + + return + }) + + _, err = conn.Execute(` + UPDATE _gopgsql_test_account + SET balance = balance + 100.0 + WHERE name = 'Wally';`) + if err != nil { + t.Error("failed to execute update:", err) + return + } + + return + }) + + var rs *ResultSet + rs, err = conn.Query("SELECT name, balance FROM _gopgsql_test_account;") + if err != nil { + t.Error("failed to query:", err) + return + } + defer rs.Close() + + have := make(map[string]float64) + want := map[string]float64{ + "Alice": 0, + "Bob": 0, + "Wally": 100, + } + var name string + var balance float64 + var fetched bool + + for { + fetched, err = rs.ScanNext(&name, &balance) + if err != nil { + t.Error("failed to scan next:", err) + return + } + if !fetched { + break + } + + have[name] = balance + } + + for name, haveBalance := range have { + wantBalance := want[name] + + if math.Fabs(haveBalance-wantBalance) > 0.000001 { + t.Errorf("name: %s have: %f, but want: %f", name, haveBalance, wantBalance) + } + } + }) +}