Skip to content

Commit 366e88c

Browse files
authored
SWIFT-958 Ensure nil is returned from next() before ClosedCursorError (#521)
1 parent b9e8cba commit 366e88c

File tree

7 files changed

+99
-22
lines changed

7 files changed

+99
-22
lines changed

Sources/MongoSwift/CursorCommon.swift

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,13 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
9898
/// Indicates that the cursor is still open. Stores a `MongocCursorWrapper`, along with the source
9999
/// connection, and possibly session to ensure they are kept alive as long as the cursor is.
100100
case open(cursor: CursorKind, connection: Connection, session: ClientSession?)
101+
102+
/// Indicates the cursor last returned an error from `next`. `nil` will be returned if
103+
/// `next` is called again, and the cursor will be moved to `closed` state.
104+
case error
105+
106+
/// Indicates the cursor has completed iteration, either by exhausting its results or by returning
107+
/// an error and then `nil`.
101108
case closed
102109
}
103110

@@ -186,13 +193,13 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
186193

187194
guard mongocCursor.next(outPtr: out) else {
188195
if let error = self.getMongocError() {
189-
self.close()
196+
self.close(fromError: true)
190197
throw error
191198
}
192199

193200
// if we've reached the end of the cursor, close it.
194201
if !self.type.isTailable || !mongocCursor.more() {
195-
self.close()
202+
self.close(fromError: false)
196203
}
197204

198205
return nil
@@ -209,12 +216,17 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
209216
/// Close this cursor
210217
///
211218
/// This method should only be called while the lock is held.
212-
private func close() {
219+
private func close(fromError: Bool) {
213220
guard case let .open(mongocCursor, _, _) = self.state else {
214221
return
215222
}
216223
mongocCursor.destroy()
217-
self.state = .closed
224+
225+
if fromError {
226+
self.state = .error
227+
} else {
228+
self.state = .closed
229+
}
218230
}
219231

220232
/// This initializer is blocking and should only be run via the executor.
@@ -232,7 +244,7 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
232244

233245
// If there was an error constructing the cursor, throw it.
234246
if let error = self.getMongocError() {
235-
self.close()
247+
self.close(fromError: true)
236248
throw error
237249
}
238250

@@ -265,7 +277,7 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
265277
switch self.state {
266278
case .open:
267279
return true
268-
case .closed:
280+
case .closed, .error:
269281
return false
270282
}
271283
}
@@ -275,7 +287,11 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
275287
internal func next() throws -> BSONDocument? {
276288
try self.lock.withLock {
277289
guard self._isAlive else {
278-
throw ClosedCursorError
290+
guard case .error = self.state else {
291+
throw ClosedCursorError
292+
}
293+
self.state = .closed
294+
return nil
279295
}
280296

281297
if case let .cached(result) = self.cached.clear() {
@@ -339,7 +355,7 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
339355
self.isClosing.store(true)
340356
self.lock.withLock {
341357
self.cached = .none
342-
self.close()
358+
self.close(fromError: false)
343359
}
344360
self.isClosing.store(false)
345361
}

Sources/TestsCommon/CommonTestUtils.swift

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,17 @@ public struct TestRequirement: Decodable {
158158
private let maxServerVersion: ServerVersion?
159159
private let topology: [TestTopologyConfiguration]?
160160

161+
public static let failCommandSupport: [TestRequirement] = [
162+
TestRequirement(
163+
minServerVersion: ServerVersion.mongodFailCommandSupport,
164+
acceptableTopologies: [.single, .replicaSet]
165+
),
166+
TestRequirement(
167+
minServerVersion: ServerVersion.mongosFailCommandSupport,
168+
acceptableTopologies: [.sharded]
169+
)
170+
]
171+
161172
public init(
162173
minServerVersion: ServerVersion? = nil,
163174
maxServerVersion: ServerVersion? = nil,
@@ -284,23 +295,25 @@ public func sortedEqual(_ expectedValue: BSONDocument?) -> Predicate<BSONDocumen
284295
}
285296
}
286297

298+
public func printSkipMessage(testName: String, reason: String) {
299+
print("Skipping test case \"\(testName)\": \(reason)")
300+
}
301+
287302
/// Prints a message if a server version or topology requirement is not met and a test is skipped
288303
public func printSkipMessage(
289304
testName: String,
290305
unmetRequirement: UnmetRequirement
291306
) {
307+
let reason: String
292308
switch unmetRequirement {
293309
case let .minServerVersion(actual, required):
294-
print("Skipping test case \"\(testName)\": minimum required server " +
295-
"version \(required) not met by current server version \(actual)")
296-
310+
reason = "minimum required server version \(required) not met by current server version \(actual)"
297311
case let .maxServerVersion(actual, required):
298-
print("Skipping test case \"\(testName)\": maximum required server " +
299-
"version \(required) not met by current server version \(actual)")
300-
312+
reason = "maximum required server version \(required) not met by current server version \(actual)"
301313
case let .topology(actual, required):
302-
print("Skipping \(testName) due to unsupported topology type \(actual), supported topologies are: \(required)")
314+
reason = "unsupported topology type \(actual), supported topologies are: \(required)"
303315
}
316+
printSkipMessage(testName: testName, reason: reason)
304317
}
305318

306319
public func unsupportedTopologyMessage(

Sources/TestsCommon/ServerVersion.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ import Foundation
22

33
/// A struct representing a server version.
44
public struct ServerVersion: Comparable, Decodable, CustomStringConvertible {
5+
public static let mongodFailCommandSupport = ServerVersion(major: 4, minor: 0)
6+
public static let mongosFailCommandSupport = ServerVersion(major: 4, minor: 1, patch: 5)
7+
58
let major: Int
69
let minor: Int
710
let patch: Int

Tests/LinuxMain.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ extension MongoCursorTests {
263263
("testKill", testKill),
264264
("testKillTailable", testKillTailable),
265265
("testLazySequence", testLazySequence),
266+
("testCursorTerminatesOnError", testCursorTerminatesOnError),
267+
("testCursorClosedError", testCursorClosedError),
266268
]
267269
}
268270

Tests/MongoSwiftSyncTests/MongoCursorTests.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ final class MongoCursorTests: MongoSwiftTestCase {
106106
expect(try cursor.next()?.get()).to(throwError(expectedError))
107107
// cursor should be closed now that it errored
108108
expect(cursor.isAlive()).to(beFalse())
109+
expect(cursor.next()).to(beNil())
109110

110111
// iterating dead cursor should error
111112
expect(try cursor.next()?.get()).to(throwError(errorType: MongoError.LogicError.self))
@@ -238,4 +239,45 @@ final class MongoCursorTests: MongoSwiftTestCase {
238239
expect(cursor.isAlive()).to(beFalse())
239240
}
240241
}
242+
243+
func testCursorTerminatesOnError() throws {
244+
try self.withTestNamespace { client, _, coll in
245+
guard try client.supportsFailCommand() else {
246+
printSkipMessage(testName: self.name, reason: "failCommand not supported")
247+
return
248+
}
249+
250+
try coll.insertOne([:])
251+
try coll.insertOne([:])
252+
253+
let cursor = try coll.find([:], options: FindOptions(batchSize: 1))
254+
255+
let fp = FailPoint.failCommand(failCommands: ["getMore"], mode: .times(1), errorCode: 10)
256+
try fp.enable()
257+
defer { fp.disable() }
258+
259+
var count = 0
260+
for result in cursor {
261+
expect(count).to(beLessThan(2))
262+
if count >= 2 {
263+
break
264+
}
265+
// getmore should return error
266+
if count == 1 {
267+
expect(try result.get()).to(throwError())
268+
if result.isSuccess { break }
269+
}
270+
count += 1
271+
}
272+
}
273+
}
274+
275+
func testCursorClosedError() throws {
276+
try self.withTestNamespace { _, _, coll in
277+
let cursor = try coll.find([:], options: FindOptions(batchSize: 1))
278+
279+
for _ in cursor {}
280+
expect(try cursor.next()?.get()).to(throwError(errorType: MongoError.LogicError.self))
281+
}
282+
}
241283
}

Tests/MongoSwiftSyncTests/SyncTestUtils.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ extension MongoClient {
7171
return testRequirement.getUnmetRequirement(givenCurrent: serverVersion, topologyType)
7272
}
7373

74+
internal func meetsAnyRequirement(in requirements: [TestRequirement]) throws -> Bool {
75+
try requirements.contains {
76+
try self.getUnmetRequirement($0) == nil
77+
}
78+
}
79+
7480
/// Get the max wire version of the primary.
7581
internal func maxWireVersion() throws -> Int {
7682
let options = RunCommandOptions(readPreference: .primary)
@@ -114,13 +120,7 @@ extension MongoClient {
114120
}
115121

116122
internal func supportsFailCommand() throws -> Bool {
117-
let version = try self.serverVersion()
118-
switch MongoSwiftTestCase.topologyType {
119-
case .sharded:
120-
return version >= ServerVersion(major: 4, minor: 1, patch: 5)
121-
default:
122-
return version >= ServerVersion(major: 4, minor: 0)
123-
}
123+
try self.meetsAnyRequirement(in: TestRequirement.failCommandSupport)
124124
}
125125
}
126126

Tests/MongoSwiftTests/MongoCursorTests.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ final class AsyncMongoCursorTests: MongoSwiftTestCase {
129129
expect(try cursor.next().wait()).to(throwError(expectedError))
130130
// cursor should be closed now that it errored
131131
expect(try cursor.isAlive().wait()).to(beFalse())
132+
expect(try cursor.next().wait()).to(beNil())
132133

133134
// iterating dead cursor should error
134135
expect(try cursor.next().wait()).to(throwError(errorType: MongoError.LogicError.self))

0 commit comments

Comments
 (0)