Skip to content

RUST-1048 add default_database api for client #488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ impl Client {
Database::new(self.clone(), name, Some(options))
}

/// Gets a handle to the default database specified in the `ClientOptions` or MongoDB connection
/// string used to construct this `Client`.
///
/// If no default database was specified, `None` will be returned.
pub fn default_database(&self) -> Option<Database> {
self.inner
.options
.default_database
.as_ref()
.map(|db_name| self.database(db_name))
}

async fn list_databases_common(
&self,
filter: impl Into<Option<Document>>,
Expand Down
16 changes: 15 additions & 1 deletion src/client/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,12 @@ pub struct ClientOptions {
#[builder(default)]
pub server_selection_timeout: Option<Duration>,

/// Default database for this client.
///
/// By default, no default database is specified.
#[builder(default)]
pub default_database: Option<String>,

#[builder(default, setter(skip))]
pub(crate) socket_timeout: Option<Duration>,

Expand Down Expand Up @@ -702,6 +708,7 @@ struct ClientOptionsParser {
pub zlib_compression: Option<i32>,
pub direct_connection: Option<bool>,
pub credential: Option<Credential>,
pub default_database: Option<String>,
max_staleness: Option<Duration>,
tls_insecure: Option<bool>,
auth_mechanism: Option<AuthMechanism>,
Expand Down Expand Up @@ -931,6 +938,7 @@ impl From<ClientOptionsParser> for ClientOptions {
retry_writes: parser.retry_writes,
socket_timeout: parser.socket_timeout,
direct_connection: parser.direct_connection,
default_database: parser.default_database,
driver_info: None,
credential: parser.credential,
cmap_event_handler: None,
Expand Down Expand Up @@ -969,6 +977,9 @@ impl ClientOptions {
///
/// The format of a MongoDB connection string is described [here](https://docs.mongodb.com/manual/reference/connection-string/#connection-string-formats).
///
/// Note that [default_database](ClientOptions::default_database) will be set from
/// `/defaultauthdb` in connection string.
///
/// The following options are supported in the options query string:
///
/// * `appName`: maps to the `app_name` field
Expand Down Expand Up @@ -1468,7 +1479,7 @@ impl ClientOptionsParser {
credential.source = options
.auth_source
.clone()
.or(db)
.or(db.clone())
.or_else(|| Some("admin".into()));
} else if authentication_requested {
return Err(ErrorKind::InvalidArgument {
Expand All @@ -1481,6 +1492,9 @@ impl ClientOptionsParser {
}
};

// set default database.
options.default_database = db;

if options.tls.is_none() && options.srv {
options.tls = Some(Tls::Enabled(Default::default()));
}
Expand Down
38 changes: 38 additions & 0 deletions src/client/options/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,41 @@ async fn parse_unknown_options() {
.await;
parse_uri("maxstalenessms", Some("maxstalenessseconds")).await;
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn parse_with_default_database() {
let uri = "mongodb://localhost/abc";

assert_eq!(
ClientOptions::parse(uri).await.unwrap(),
ClientOptions {
hosts: vec![ServerAddress::Tcp {
host: "localhost".to_string(),
port: None
}],
original_uri: Some(uri.into()),
default_database: Some("abc".to_string()),
..Default::default()
}
);
}

#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
async fn parse_with_no_default_database() {
let uri = "mongodb://localhost/";

assert_eq!(
ClientOptions::parse(uri).await.unwrap(),
ClientOptions {
hosts: vec![ServerAddress::Tcp {
host: "localhost".to_string(),
port: None
}],
original_uri: Some(uri.into()),
default_database: None,
..Default::default()
}
);
}
8 changes: 8 additions & 0 deletions src/sync/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ impl Client {
Database::new(self.async_client.database_with_options(name, options))
}

/// Gets a handle to the default database specified in the `ClientOptions` or MongoDB connection
/// string used to construct this `Client`.
///
/// If no default database was specified, `None` will be returned.
pub fn default_database(&self) -> Option<Database> {
self.async_client.default_database().map(Database::new)
}

/// Gets information about each database present in the cluster the Client is connected to.
pub fn list_databases(
&self,
Expand Down
30 changes: 30 additions & 0 deletions src/sync/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,36 @@ fn client() {
assert!(db_names.contains(&function_name!().to_string()));
}

#[test]
#[function_name::named]
fn default_database() {
// here we just test default database name matched, the database interactive logic
// is tested in `database`.
let _guard: RwLockReadGuard<()> = RUNTIME.block_on(async { LOCK.run_concurrently().await });

let options = CLIENT_OPTIONS.clone();
let client = Client::with_options(options).expect("client creation should succeed");
let default_db = client.default_database();
assert!(default_db.is_none());

// create client througth options.
let mut options = CLIENT_OPTIONS.clone();
options.default_database = Some("abcd".to_string());
let client = Client::with_options(options).expect("client creation should succeed");
let default_db = client
.default_database()
.expect("should have a default database.");
assert_eq!(default_db.name(), "abcd");

// create client directly through uri_str.
let client = Client::with_uri_str("mongodb://localhost:27017/abcd")
.expect("client creation should succeed");
let default_db = client
.default_database()
.expect("should have a default database.");
assert_eq!(default_db.name(), "abcd");
}

#[test]
#[function_name::named]
fn database() {
Expand Down