|
1 | 1 | package com.github.michaelbull.jdbc |
2 | 2 |
|
3 | | -import com.github.michaelbull.jdbc.context.CoroutineConnection |
4 | 3 | import com.github.michaelbull.jdbc.context.CoroutineTransaction |
5 | 4 | import com.github.michaelbull.jdbc.context.connection |
6 | | -import com.github.michaelbull.jdbc.context.transaction |
7 | 5 | import kotlinx.coroutines.CoroutineScope |
8 | 6 | import kotlinx.coroutines.withContext |
9 | 7 | import java.sql.Connection |
10 | 8 | import kotlin.contracts.InvocationKind |
11 | 9 | import kotlin.contracts.contract |
12 | | -import kotlin.coroutines.CoroutineContext |
13 | 10 | import kotlin.coroutines.coroutineContext |
14 | 11 |
|
15 | 12 | /** |
16 | 13 | * Calls the specified suspending [block] in the context of a [CoroutineTransaction], suspends until it completes, and |
17 | 14 | * returns the result. |
18 | 15 | * |
19 | | - * When there exists a [CoroutineTransaction] in the current [CoroutineContext], the [block] will be immediately invoked |
20 | | - * if the [transaction is running][CoroutineTransaction.isRunning], otherwise an [IllegalStateException] will be thrown. |
| 16 | + * When the [coroutineContext] has no [CoroutineTransaction], the specified suspending [block] will be |
| 17 | + * [ran transactionally][runTransactionally] [with the context of a connection][withConnection]. |
21 | 18 | * |
22 | | - * When no [CoroutineTransaction] exists in the current [CoroutineContext], the [block] will be invoked |
23 | | - * [with the context][withContext] of a new [CoroutineTransaction]. |
| 19 | + * When the [coroutineContext] has an [incomplete][CoroutineTransaction.incomplete] [CoroutineTransaction], the |
| 20 | + * specified suspending [block] will be called [with this context][withContext]. |
24 | 21 | * |
25 | | - * The [block] will be invoked [with a connection][withConnection] in its [CoroutineContext]. The connection's |
26 | | - * [autoCommit][Connection.setAutoCommit] is set to `false` before the invocation. If the [block] throws a [Throwable], |
27 | | - * the transaction will [rollback][Connection.rollback] and re-throw the [Throwable], otherwise the transaction will |
28 | | - * [commit][Connection.commit] and return the result of type [T]. |
| 22 | + * When the [coroutineContext] has a [completed][CoroutineTransaction.completed] [CoroutineTransaction], an |
| 23 | + * [IllegalStateException] will be thrown. |
29 | 24 | */ |
30 | 25 | suspend inline fun <T> transaction(crossinline block: suspend CoroutineScope.() -> T): T { |
31 | 26 | contract { |
32 | 27 | callsInPlace(block, InvocationKind.AT_MOST_ONCE) |
33 | 28 | } |
34 | 29 |
|
35 | | - val existingTransaction = coroutineContext.transaction |
| 30 | + val existingTransaction = coroutineContext[CoroutineTransaction] |
36 | 31 |
|
37 | 32 | return when { |
38 | | - existingTransaction == null -> withContext(CoroutineTransaction()) { |
| 33 | + existingTransaction == null -> { |
39 | 34 | withConnection { |
40 | | - execute(block) |
| 35 | + runTransactionally { |
| 36 | + block() |
| 37 | + } |
41 | 38 | } |
42 | 39 | } |
43 | 40 |
|
44 | | - existingTransaction.isRunning -> withContext(coroutineContext) { |
45 | | - block() |
| 41 | + existingTransaction.incomplete -> { |
| 42 | + withContext(coroutineContext) { |
| 43 | + block() |
| 44 | + } |
46 | 45 | } |
47 | 46 |
|
48 | 47 | else -> error("Attempted to start new transaction within: $existingTransaction") |
49 | 48 | } |
50 | 49 | } |
51 | 50 |
|
52 | 51 | /** |
53 | | - * [Starts][CoroutineTransaction.start] the current [CoroutineTransaction] and sets the |
54 | | - * current [CoroutineConnection]'s [autoCommit][Connection.setAutoCommit] to `false`, calls the specified suspending |
55 | | - * [block], suspends until it completes, then [commits][Connection.commit] and returns the result. |
| 52 | + * Calls the specified suspending [block] [with the context][withContext] of a [CoroutineTransaction] and returns its |
| 53 | + * result. |
| 54 | + * |
| 55 | + * If invocation of the suspending [block] was successful, [commit][Connection.commit] is then called on the |
| 56 | + * [Connection] in the [coroutineContext]. |
56 | 57 | * |
57 | | - * If the [block] throws a [Throwable], the connection will [rollback][Connection.rollback] and not |
58 | | - * [commit][Connection.commit]. |
| 58 | + * If invocation of the suspending [block] throws a [Throwable] exception, [rollback][Connection.rollback] is then |
| 59 | + * called on the [Connection] in the [coroutineContext] and the exception is thrown. |
59 | 60 | */ |
60 | 61 | @PublishedApi |
61 | | -internal suspend inline fun <T> execute(crossinline block: suspend CoroutineScope.() -> T): T { |
| 62 | +internal suspend inline fun <T> runTransactionally(crossinline block: suspend CoroutineScope.() -> T): T { |
62 | 63 | contract { |
63 | | - callsInPlace(block, InvocationKind.AT_MOST_ONCE) |
| 64 | + callsInPlace(block, InvocationKind.EXACTLY_ONCE) |
64 | 65 | } |
65 | 66 |
|
66 | | - val transaction = coroutineContext.transaction ?: error("No transaction in context") |
67 | | - transaction.start() |
| 67 | + coroutineContext.connection.runWithManualCommit { |
| 68 | + val transaction = CoroutineTransaction() |
| 69 | + |
| 70 | + try { |
| 71 | + val result = withContext(transaction) { |
| 72 | + block() |
| 73 | + } |
| 74 | + |
| 75 | + commit() |
| 76 | + return result |
| 77 | + } catch (ex: Throwable) { |
| 78 | + rollback() |
| 79 | + throw ex |
| 80 | + } finally { |
| 81 | + transaction.complete() |
| 82 | + } |
| 83 | + } |
| 84 | +} |
| 85 | + |
| 86 | +/** |
| 87 | + * Disables [autoCommit][Connection.getAutoCommit] mode on `this` [Connection], then calls a specific function [block] |
| 88 | + * with `this` [Connection] as its receiver and returns its result, then sets the [autoCommit][Connection.getAutoCommit] |
| 89 | + * mode on `this` [Connection] back to its original value. |
| 90 | + */ |
| 91 | +@PublishedApi |
| 92 | +internal inline fun <T> Connection.runWithManualCommit(block: Connection.() -> T): T { |
| 93 | + contract { |
| 94 | + callsInPlace(block, InvocationKind.EXACTLY_ONCE) |
| 95 | + } |
68 | 96 |
|
69 | | - val connection = coroutineContext.connection |
70 | | - connection.autoCommit = false |
| 97 | + val before = autoCommit |
71 | 98 |
|
72 | | - try { |
73 | | - val result = withContext(coroutineContext) { block() } |
74 | | - transaction.complete() |
75 | | - connection.commit() |
76 | | - return result |
77 | | - } catch (ex: Throwable) { |
78 | | - transaction.complete() |
79 | | - connection.rollback() |
80 | | - throw ex |
| 99 | + return try { |
| 100 | + autoCommit = false |
| 101 | + this.run(block) |
| 102 | + } finally { |
| 103 | + autoCommit = before |
81 | 104 | } |
82 | 105 | } |
0 commit comments