Skip to content

Commit

Permalink
Merge pull request blackbeam#236 from blackbeam/capabilities-setup
Browse files Browse the repository at this point in the history
Capabilities setup
  • Loading branch information
blackbeam authored Feb 23, 2023
2 parents 8323780 + 75ade0d commit f1afb7c
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
52 changes: 52 additions & 0 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,58 @@ mod test {
Ok(())
}

#[tokio::test]
async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> {
let opts = get_opts().client_found_rows(true);
let mut conn = Conn::new(opts).await.unwrap();

"CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
.ignore(&mut conn)
.await?;

"INSERT INTO mysql.found_rows (val) VALUES (1)"
.ignore(&mut conn)
.await?;

// Inserted one row, affected should be one.
assert_eq!(conn.affected_rows(), 1);

"UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
.ignore(&mut conn)
.await?;

// The query doesn't affect any rows, but due to us wanting FOUND rows,
// this has to return one.
assert_eq!(conn.affected_rows(), 1);

Ok(())
}

#[tokio::test]
async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> {
let mut conn = Conn::new(get_opts()).await.unwrap();

"CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)"
.ignore(&mut conn)
.await?;

"INSERT INTO mysql.found_rows (val) VALUES (1)"
.ignore(&mut conn)
.await?;

// Inserted one row, affected should be one.
assert_eq!(conn.affected_rows(), 1);

"UPDATE mysql.found_rows SET val = 1 WHERE val = 1"
.ignore(&mut conn)
.await?;

// The query doesn't affect any rows.
assert_eq!(conn.affected_rows(), 0);

Ok(())
}

async fn read_binlog_streams_and_close_their_connections(
pool: Option<&Pool>,
binlog_server_ids: (u32, u32, u32),
Expand Down
48 changes: 48 additions & 0 deletions src/opts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ pub(crate) struct MysqlOpts {
///
/// Available via `secure_auth` connection url parameter.
secure_auth: bool,

/// Enables `CLIENT_FOUND_ROWS` capability (defaults to `false`).
///
/// Changes the behavior of the affected count returned for writes (UPDATE/INSERT etc).
/// It makes MySQL return the FOUND rows instead of the AFFECTED rows.
client_found_rows: bool,
}

/// Mysql connection options.
Expand Down Expand Up @@ -721,6 +727,26 @@ impl Opts {
self.inner.mysql_opts.secure_auth
}

/// Returns `true` if `CLIENT_FOUND_ROWS` capability is enabled (defaults to `false`).
///
/// `CLIENT_FOUND_ROWS` changes the behavior of the affected count returned for writes
/// (UPDATE/INSERT etc). It makes MySQL return the FOUND rows instead of the AFFECTED rows.
///
/// # Connection URL
///
/// Use `client_found_rows` URL parameter to set this value. E.g.
///
/// ```
/// # use mysql_async::*;
/// # fn main() -> Result<()> {
/// let opts = Opts::from_url("mysql://localhost/db?client_found_rows=true")?;
/// assert!(opts.client_found_rows());
/// # Ok(()) }
/// ```
pub fn client_found_rows(&self) -> bool {
self.inner.mysql_opts.client_found_rows
}

pub(crate) fn get_capabilities(&self) -> CapabilityFlags {
let mut out = CapabilityFlags::CLIENT_PROTOCOL_41
| CapabilityFlags::CLIENT_SECURE_CONNECTION
Expand All @@ -742,6 +768,9 @@ impl Opts {
if self.inner.mysql_opts.compression.is_some() {
out |= CapabilityFlags::CLIENT_COMPRESS;
}
if self.client_found_rows() {
out |= CapabilityFlags::CLIENT_FOUND_ROWS;
}

out
}
Expand All @@ -767,6 +796,7 @@ impl Default for MysqlOpts {
max_allowed_packet: None,
wait_timeout: None,
secure_auth: true,
client_found_rows: false,
}
}
}
Expand Down Expand Up @@ -1017,6 +1047,12 @@ impl OptsBuilder {
self.opts.secure_auth = secure_auth;
self
}

/// Enables or disables `CLIENT_FOUND_ROWS` capability. See [`Opts::client_found_rows`].
pub fn client_found_rows(mut self, client_found_rows: bool) -> Self {
self.opts.client_found_rows = client_found_rows;
self
}
}

impl From<OptsBuilder> for Opts {
Expand Down Expand Up @@ -1245,6 +1281,18 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result<MysqlOpts, UrlError> {
});
}
}
} else if key == "client_found_rows" {
match bool::from_str(&*value) {
Ok(client_found_rows) => {
opts.client_found_rows = client_found_rows;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "client_found_rows".into(),
value,
});
}
}
} else if key == "socket" {
opts.socket = Some(value)
} else if key == "compression" {
Expand Down

0 comments on commit f1afb7c

Please sign in to comment.