Skip to content

Commit ce3bac3

Browse files
author
Harald Nordgård-Hansen
committed
Add support for nested transactions.
Merge the TransactionExecutor interface into QueryExecutor, and implement the neccessary support so that any QueryExecutor can start a new transaction. If the QueryExecutor is already a transaction, it will create a new savepoint in the database.
1 parent abb7f23 commit ce3bac3

File tree

6 files changed

+131
-39
lines changed

6 files changed

+131
-39
lines changed

src/main/java/com/github/pgasync/Db.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
*
66
* @author Antti Laisi
77
*/
8-
public interface Db extends QueryExecutor, TransactionExecutor, Listenable, AutoCloseable {
8+
public interface Db extends QueryExecutor, Listenable, AutoCloseable {
99

1010
/**
1111
* Closes the pool, blocks the calling thread until connections are closed.

src/main/java/com/github/pgasync/QueryExecutor.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
*/
1313
public interface QueryExecutor {
1414

15+
/**
16+
* Begins a transaction.
17+
*/
18+
Observable<Transaction> begin();
19+
1520
/**
1621
* Executes an anonymous prepared statement. Uses native PostgreSQL syntax with $arg instead of ?
1722
* to mark parameters. Supported parameter types are String, Character, Number, Time, Date, Timestamp
@@ -34,6 +39,16 @@ public interface QueryExecutor {
3439
*/
3540
Observable<ResultSet> querySet(String sql, Object... params);
3641

42+
/**
43+
* Begins a transaction.
44+
*
45+
* @param onTransaction Called when transaction is successfully started.
46+
* @param onError Called on exception thrown
47+
*/
48+
default void begin(Consumer<Transaction> onTransaction, Consumer<Throwable> onError) {
49+
begin().subscribe(onTransaction::accept, onError::accept);
50+
}
51+
3752
/**
3853
* Executes a simple query.
3954
*

src/main/java/com/github/pgasync/TransactionExecutor.java

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/main/java/com/github/pgasync/impl/PgConnection.java

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,7 @@
1414

1515
package com.github.pgasync.impl;
1616

17-
import com.github.pgasync.Connection;
18-
import com.github.pgasync.ResultSet;
19-
import com.github.pgasync.Row;
20-
import com.github.pgasync.Transaction;
21-
import com.github.pgasync.impl.conversion.DataConverter;
22-
import com.github.pgasync.impl.message.*;
23-
import rx.Observable;
24-
import rx.Subscriber;
25-
import rx.observers.Subscribers;
17+
import static com.github.pgasync.impl.message.RowDescription.ColumnDescription;
2618

2719
import java.util.ArrayList;
2820
import java.util.HashMap;
@@ -32,7 +24,27 @@
3224
import java.util.concurrent.TimeUnit;
3325
import java.util.logging.Logger;
3426

35-
import static com.github.pgasync.impl.message.RowDescription.ColumnDescription;
27+
import com.github.pgasync.Connection;
28+
import com.github.pgasync.ResultSet;
29+
import com.github.pgasync.Row;
30+
import com.github.pgasync.Transaction;
31+
import com.github.pgasync.impl.conversion.DataConverter;
32+
import com.github.pgasync.impl.message.Authentication;
33+
import com.github.pgasync.impl.message.Bind;
34+
import com.github.pgasync.impl.message.CommandComplete;
35+
import com.github.pgasync.impl.message.DataRow;
36+
import com.github.pgasync.impl.message.ExtendedQuery;
37+
import com.github.pgasync.impl.message.Message;
38+
import com.github.pgasync.impl.message.Parse;
39+
import com.github.pgasync.impl.message.PasswordMessage;
40+
import com.github.pgasync.impl.message.Query;
41+
import com.github.pgasync.impl.message.ReadyForQuery;
42+
import com.github.pgasync.impl.message.RowDescription;
43+
import com.github.pgasync.impl.message.StartupMessage;
44+
45+
import rx.Observable;
46+
import rx.Subscriber;
47+
import rx.observers.Subscribers;
3648

3749
/**
3850
* A connection to PostgreSQL backed. The postmaster forks a backend process for
@@ -184,6 +196,10 @@ static Map<String,PgColumn> getColumns(ColumnDescription[] descriptions) {
184196
*/
185197
class PgConnectionTransaction implements Transaction {
186198

199+
@Override
200+
public Observable<Transaction> begin() {
201+
return querySet("SAVEPOINT sp_1").map(rs -> new PgConnectionNestedTransaction(1));
202+
}
187203
@Override
188204
public Observable<Void> commit() {
189205
return PgConnection.this.querySet("COMMIT")
@@ -211,4 +227,30 @@ <T> Observable<T> doRollback(Throwable t) {
211227
}
212228
}
213229

230+
/**
231+
* Nested Transaction using savepoints.
232+
*/
233+
class PgConnectionNestedTransaction extends PgConnectionTransaction {
234+
235+
final int depth;
236+
237+
PgConnectionNestedTransaction(int depth) {
238+
this.depth = depth;
239+
}
240+
@Override
241+
public Observable<Transaction> begin() {
242+
return querySet("SAVEPOINT sp_" + (depth+1))
243+
.map(rs -> new PgConnectionNestedTransaction(depth+1));
244+
}
245+
@Override
246+
public Observable<Void> commit() {
247+
return PgConnection.this.querySet("RELEASE SAVEPOINT sp_" + depth)
248+
.map(rs -> (Void) null);
249+
}
250+
@Override
251+
public Observable<Void> rollback() {
252+
return PgConnection.this.querySet("ROLLBACK TO SAVEPOINT sp_" + depth)
253+
.map(rs -> (Void) null);
254+
}
255+
}
214256
}

src/main/java/com/github/pgasync/impl/PgConnectionPool.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@ class ReleasingTransaction implements Transaction {
252252
this.transaction = transaction;
253253
}
254254

255+
@Override
256+
public Observable<Transaction> begin() {
257+
// Nested transactions should not release things automatically.
258+
return transaction.begin();
259+
}
260+
255261
@Override
256262
public Observable<Void> rollback() {
257263
return transaction.rollback()

src/test/java/com/github/pgasync/impl/TransactionTest.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,61 @@ public void shouldInvalidateTxConnAfterError() throws Exception {
142142
assertEquals(0, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 22").size());
143143
}
144144

145+
@Test
146+
public void shouldSupportNestedTransactions() throws Exception {
147+
CountDownLatch sync = new CountDownLatch(1);
148+
149+
dbr.db().begin((transaction) ->
150+
transaction.begin((nested) ->
151+
nested.query("INSERT INTO TX_TEST(ID) VALUES(19)", result -> {
152+
assertEquals(1, result.updatedRows());
153+
nested.commit(() -> transaction.commit(sync::countDown, err), err);
154+
}, err),
155+
err),
156+
err);
157+
158+
assertTrue(sync.await(5, TimeUnit.SECONDS));
159+
assertEquals(1L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 19").size());
160+
}
161+
162+
@Test
163+
public void shouldRollbackNestedTransaction() throws Exception {
164+
CountDownLatch sync = new CountDownLatch(1);
165+
166+
dbr.db().begin((transaction) ->
167+
transaction.query("INSERT INTO TX_TEST(ID) VALUES(24)", result -> {
168+
assertEquals(1, result.updatedRows());
169+
transaction.begin((nested) ->
170+
nested.query("INSERT INTO TX_TEST(ID) VALUES(23)", res2 -> {
171+
assertEquals(1, res2.updatedRows());
172+
nested.rollback(() -> transaction.commit(sync::countDown, err), err);
173+
}, err), err);
174+
}, err),
175+
err);
176+
177+
assertTrue(sync.await(5, TimeUnit.SECONDS));
178+
assertEquals(1L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 24").size());
179+
assertEquals(0L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 23").size());
180+
}
181+
182+
@Test
183+
public void shouldRollbackNestedTransactionOnBackendError() throws Exception {
184+
CountDownLatch sync = new CountDownLatch(1);
185+
186+
dbr.db().begin((transaction) ->
187+
transaction.query("INSERT INTO TX_TEST(ID) VALUES(25)", result -> {
188+
assertEquals(1, result.updatedRows());
189+
transaction.begin((nested) ->
190+
nested.query("INSERT INTO TX_TEST(ID) VALUES(26)", res2 -> {
191+
assertEquals(1, res2.updatedRows());
192+
nested.query("INSERT INTO TD_TEST(ID) VALUES(26)",
193+
fail, t -> transaction.commit(sync::countDown, err));
194+
}, err), err);
195+
}, err),
196+
err);
197+
198+
assertTrue(sync.await(5, TimeUnit.SECONDS));
199+
assertEquals(1L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 25").size());
200+
assertEquals(0L, dbr.query("SELECT ID FROM TX_TEST WHERE ID = 26").size());
201+
}
145202
}

0 commit comments

Comments
 (0)