Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions Sources/MongoSwift/CursorCommon.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
/// Indicates that the cursor is still open. Stores a `MongocCursorWrapper`, along with the source
/// connection, and possibly session to ensure they are kept alive as long as the cursor is.
case open(cursor: CursorKind, connection: Connection, session: ClientSession?)

/// Indicates the cursor last returned an error from `next`. `nil` will be returned if
/// `next` is called again, and the cursor will be moved to `closed` state.
case error

/// Indicates the cursor has completed iteration, either by exhausting its results or by returning
/// an error and then `nil`.
case closed
}

Expand Down Expand Up @@ -186,13 +193,13 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {

guard mongocCursor.next(outPtr: out) else {
if let error = self.getMongocError() {
self.close()
self.close(fromError: true)
throw error
}

// if we've reached the end of the cursor, close it.
if !self.type.isTailable || !mongocCursor.more() {
self.close()
self.close(fromError: false)
}

return nil
Expand All @@ -209,12 +216,17 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
/// Close this cursor
///
/// This method should only be called while the lock is held.
private func close() {
private func close(fromError: Bool) {
guard case let .open(mongocCursor, _, _) = self.state else {
return
}
mongocCursor.destroy()
self.state = .closed

if fromError {
self.state = .error
} else {
self.state = .closed
}
}

/// This initializer is blocking and should only be run via the executor.
Expand All @@ -232,7 +244,7 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {

// If there was an error constructing the cursor, throw it.
if let error = self.getMongocError() {
self.close()
self.close(fromError: true)
throw error
}

Expand Down Expand Up @@ -265,7 +277,7 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
switch self.state {
case .open:
return true
case .closed:
case .closed, .error:
return false
}
}
Expand All @@ -275,7 +287,11 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
internal func next() throws -> BSONDocument? {
try self.lock.withLock {
guard self._isAlive else {
throw ClosedCursorError
guard case .error = self.state else {
throw ClosedCursorError
}
self.state = .closed
return nil
}

if case let .cached(result) = self.cached.clear() {
Expand Down Expand Up @@ -339,7 +355,7 @@ internal class Cursor<CursorKind: MongocCursorWrapper> {
self.isClosing.store(true)
self.lock.withLock {
self.cached = .none
self.close()
self.close(fromError: false)
}
self.isClosing.store(false)
}
Expand Down
27 changes: 20 additions & 7 deletions Sources/TestsCommon/CommonTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ public struct TestRequirement: Decodable {
private let maxServerVersion: ServerVersion?
private let topology: [TestTopologyConfiguration]?

public static let failCommandSupport: [TestRequirement] = [
TestRequirement(
minServerVersion: ServerVersion.mongodFailCommandSupport,
acceptableTopologies: [.single, .replicaSet]
),
TestRequirement(
minServerVersion: ServerVersion.mongosFailCommandSupport,
acceptableTopologies: [.sharded]
)
]

public init(
minServerVersion: ServerVersion? = nil,
maxServerVersion: ServerVersion? = nil,
Expand Down Expand Up @@ -284,23 +295,25 @@ public func sortedEqual(_ expectedValue: BSONDocument?) -> Predicate<BSONDocumen
}
}

public func printSkipMessage(testName: String, reason: String) {
print("Skipping test case \"\(testName)\": \(reason)")
}

/// Prints a message if a server version or topology requirement is not met and a test is skipped
public func printSkipMessage(
testName: String,
unmetRequirement: UnmetRequirement
) {
let reason: String
switch unmetRequirement {
case let .minServerVersion(actual, required):
print("Skipping test case \"\(testName)\": minimum required server " +
"version \(required) not met by current server version \(actual)")

reason = "minimum required server version \(required) not met by current server version \(actual)"
case let .maxServerVersion(actual, required):
print("Skipping test case \"\(testName)\": maximum required server " +
"version \(required) not met by current server version \(actual)")

reason = "maximum required server version \(required) not met by current server version \(actual)"
case let .topology(actual, required):
print("Skipping \(testName) due to unsupported topology type \(actual), supported topologies are: \(required)")
reason = "unsupported topology type \(actual), supported topologies are: \(required)"
}
printSkipMessage(testName: testName, reason: reason)
}

public func unsupportedTopologyMessage(
Expand Down
3 changes: 3 additions & 0 deletions Sources/TestsCommon/ServerVersion.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import Foundation

/// A struct representing a server version.
public struct ServerVersion: Comparable, Decodable, CustomStringConvertible {
public static let mongodFailCommandSupport = ServerVersion(major: 4, minor: 0)
public static let mongosFailCommandSupport = ServerVersion(major: 4, minor: 1, patch: 5)

let major: Int
let minor: Int
let patch: Int
Expand Down
2 changes: 2 additions & 0 deletions Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ extension MongoCursorTests {
("testKill", testKill),
("testKillTailable", testKillTailable),
("testLazySequence", testLazySequence),
("testCursorTerminatesOnError", testCursorTerminatesOnError),
("testCursorClosedError", testCursorClosedError),
]
}

Expand Down
42 changes: 42 additions & 0 deletions Tests/MongoSwiftSyncTests/MongoCursorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ final class MongoCursorTests: MongoSwiftTestCase {
expect(try cursor.next()?.get()).to(throwError(expectedError))
// cursor should be closed now that it errored
expect(cursor.isAlive()).to(beFalse())
expect(cursor.next()).to(beNil())

// iterating dead cursor should error
expect(try cursor.next()?.get()).to(throwError(errorType: MongoError.LogicError.self))
Expand Down Expand Up @@ -238,4 +239,45 @@ final class MongoCursorTests: MongoSwiftTestCase {
expect(cursor.isAlive()).to(beFalse())
}
}

func testCursorTerminatesOnError() throws {
try self.withTestNamespace { client, _, coll in
guard try client.supportsFailCommand() else {
printSkipMessage(testName: self.name, reason: "failCommand not supported")
return
}

try coll.insertOne([:])
try coll.insertOne([:])

let cursor = try coll.find([:], options: FindOptions(batchSize: 1))

let fp = FailPoint.failCommand(failCommands: ["getMore"], mode: .times(1), errorCode: 10)
try fp.enable()
defer { fp.disable() }

var count = 0
for result in cursor {
expect(count).to(beLessThan(2))
if count >= 2 {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not get these tests to halt after failing an expect, so I manually added breaks. I tried setting self.continueAfterFailure = false, but it didn't seem to work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

huh, weird 🤔

break
}
// getmore should return error
if count == 1 {
expect(try result.get()).to(throwError())
if result.isSuccess { break }
}
count += 1
}
}
}

func testCursorClosedError() throws {
try self.withTestNamespace { _, _, coll in
let cursor = try coll.find([:], options: FindOptions(batchSize: 1))

for _ in cursor {}
expect(try cursor.next()?.get()).to(throwError(errorType: MongoError.LogicError.self))
}
}
}
14 changes: 7 additions & 7 deletions Tests/MongoSwiftSyncTests/SyncTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ extension MongoClient {
return testRequirement.getUnmetRequirement(givenCurrent: serverVersion, topologyType)
}

internal func meetsAnyRequirement(in requirements: [TestRequirement]) throws -> Bool {
try requirements.contains {
try self.getUnmetRequirement($0) == nil
}
}

/// Get the max wire version of the primary.
internal func maxWireVersion() throws -> Int {
let options = RunCommandOptions(readPreference: .primary)
Expand Down Expand Up @@ -114,13 +120,7 @@ extension MongoClient {
}

internal func supportsFailCommand() throws -> Bool {
let version = try self.serverVersion()
switch MongoSwiftTestCase.topologyType {
case .sharded:
return version >= ServerVersion(major: 4, minor: 1, patch: 5)
default:
return version >= ServerVersion(major: 4, minor: 0)
}
try self.meetsAnyRequirement(in: TestRequirement.failCommandSupport)
}
}

Expand Down
1 change: 1 addition & 0 deletions Tests/MongoSwiftTests/MongoCursorTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ final class AsyncMongoCursorTests: MongoSwiftTestCase {
expect(try cursor.next().wait()).to(throwError(expectedError))
// cursor should be closed now that it errored
expect(try cursor.isAlive().wait()).to(beFalse())
expect(try cursor.next().wait()).to(beNil())

// iterating dead cursor should error
expect(try cursor.next().wait()).to(throwError(errorType: MongoError.LogicError.self))
Expand Down